mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
rename and fix
This commit is contained in:
@@ -45,45 +45,14 @@ dataloader = torch.utils.data.DataLoader(
|
|||||||
batch = next(iter(dataloader))
|
batch = next(iter(dataloader))
|
||||||
|
|
||||||
batch = pre_processor(batch)
|
batch = pre_processor(batch)
|
||||||
|
|
||||||
|
# Test training forward pass
|
||||||
policy.train()
|
policy.train()
|
||||||
# run inference
|
|
||||||
# action = policy.select_action(batch)
|
|
||||||
loss, loss_dict = policy.forward(batch)
|
loss, loss_dict = policy.forward(batch)
|
||||||
# import requests
|
print(f"Training loss: {loss_dict}")
|
||||||
# from PIL import Image
|
|
||||||
# from transformers import AutoProcessor
|
|
||||||
# model = policy.model.paligemma_with_expert.paligemma
|
|
||||||
# model = model.to(device="cuda", dtype=torch.bfloat16)
|
|
||||||
# model.eval()
|
|
||||||
# prompt = "Describe this image."
|
|
||||||
# url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
|
||||||
# image = Image.open(requests.get(url, stream=True).raw)
|
|
||||||
# processor = AutoProcessor.from_pretrained(
|
|
||||||
# "google/paligemma-3b-pt-224",
|
|
||||||
# )
|
|
||||||
# inputs = processor(image, prompt, return_tensors="pt").to(model.device)
|
|
||||||
# print("generating...")
|
|
||||||
# output = model.generate(
|
|
||||||
# **inputs,
|
|
||||||
# max_new_tokens=50,
|
|
||||||
# use_cache=True, # default dynamic cache
|
|
||||||
# )
|
|
||||||
# print(processor.decode(output[0], skip_special_tokens=True))
|
|
||||||
|
|
||||||
|
# Test inference
|
||||||
# # other model
|
policy.eval()
|
||||||
# from transformers import PaliGemmaForConditionalGeneration
|
with torch.no_grad():
|
||||||
# model = PaliGemmaForConditionalGeneration.from_pretrained(
|
actions = policy.predict_action_chunk(batch)
|
||||||
# "google/paligemma2-3b-pt-224",
|
print(f"Predicted actions shape: {actions.shape}")
|
||||||
# torch_dtype=torch.bfloat16,
|
|
||||||
# device_map="auto",
|
|
||||||
# )
|
|
||||||
# model.eval()
|
|
||||||
# print("generating...")
|
|
||||||
# output = model.generate(
|
|
||||||
# **inputs,
|
|
||||||
# max_new_tokens=100,
|
|
||||||
# use_cache=True, # default dynamic cache
|
|
||||||
# )
|
|
||||||
# print("Model 2 output:")
|
|
||||||
# print(processor.decode(output[0], skip_special_tokens=True))
|
|
||||||
@@ -48,10 +48,10 @@ from lerobot.utils.constants import (
|
|||||||
ACTION,
|
ACTION,
|
||||||
OBS_LANGUAGE_ATTENTION_MASK,
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
OBS_LANGUAGE_TOKENS,
|
OBS_LANGUAGE_TOKENS,
|
||||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS,
|
OBS_LANGUAGE_PROMPT_TOKENS,
|
||||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK,
|
OBS_LANGUAGE_PROMPT_ATTENTION_MASK,
|
||||||
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS,
|
OBS_LANGUAGE_TARGET_TOKENS,
|
||||||
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK,
|
OBS_LANGUAGE_TARGET_ATTENTION_MASK,
|
||||||
OPENPI_ATTENTION_MASK_VALUE,
|
OPENPI_ATTENTION_MASK_VALUE,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -609,21 +609,22 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
return time.to(dtype=torch.float32, device=device)
|
return time.to(dtype=torch.float32, device=device)
|
||||||
|
|
||||||
def embed_prefix(
|
def embed_prefix(
|
||||||
self, images, img_masks, tokens, subtask_tokens, masks, subtask_masks
|
self, images, img_masks, prompt_tokens, target_tokens, prompt_masks, target_masks=None
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
|
||||||
"""Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer.
|
"""Embed images with SigLIP, prompt tokens, and optionally target tokens with embedding layer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
images: List of image tensors
|
images: List of image tensors
|
||||||
img_masks: List of image masks
|
img_masks: List of image masks
|
||||||
tokens: Language instruction tokens
|
prompt_tokens: Prompt tokens (input for generation)
|
||||||
subtask_tokens: Subtask tokens to predict (can be None for inference)
|
target_tokens: Target tokens to predict (can be None for inference)
|
||||||
masks: Attention masks for tokens
|
prompt_masks: Attention masks for prompt tokens
|
||||||
|
target_masks: Attention masks for target tokens
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided)]
|
embs: Concatenated embeddings [images, prompt_tokens, (target_tokens if provided)]
|
||||||
pad_masks: Padding masks
|
pad_masks: Padding masks
|
||||||
att_masks: Attention masks (with causal masking for subtask prediction if subtask_tokens provided)
|
att_masks: Attention masks (with causal masking for target prediction if target_tokens provided)
|
||||||
total_T_images: Total number of image tokens
|
total_T_images: Total number of image tokens
|
||||||
"""
|
"""
|
||||||
embs = []
|
embs = []
|
||||||
@@ -645,36 +646,36 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
att_masks += [0] * num_img_embs # Images can attend to all previous tokens
|
att_masks += [0] * num_img_embs # Images can attend to all previous tokens
|
||||||
total_T_images += num_img_embs
|
total_T_images += num_img_embs
|
||||||
|
|
||||||
# Process language instruction tokens
|
# Process prompt tokens
|
||||||
def lang_embed_func(tokens):
|
def prompt_embed_func(prompt_tokens):
|
||||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
prompt_emb = self.paligemma_with_expert.embed_language_tokens(prompt_tokens)
|
||||||
lang_emb_dim = lang_emb.shape[-1]
|
prompt_emb_dim = prompt_emb.shape[-1]
|
||||||
return lang_emb * math.sqrt(lang_emb_dim)
|
return prompt_emb * math.sqrt(prompt_emb_dim)
|
||||||
|
|
||||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
prompt_emb = self._apply_checkpoint(prompt_embed_func, prompt_tokens)
|
||||||
embs.append(lang_emb)
|
embs.append(prompt_emb)
|
||||||
pad_masks.append(masks)
|
pad_masks.append(prompt_masks)
|
||||||
|
|
||||||
num_lang_embs = lang_emb.shape[1]
|
num_prompt_embs = prompt_emb.shape[1]
|
||||||
att_masks += [0] * num_lang_embs # Language tokens can attend to all previous tokens (images + tokens)
|
att_masks += [0] * num_prompt_embs # Prompt tokens can attend to all previous tokens (images + prompt)
|
||||||
|
|
||||||
# Process subtask tokens if provided (these are predicted, so use causal masking)
|
# Process target tokens if provided (these are predicted, so use causal masking)
|
||||||
if subtask_tokens is not None:
|
if target_tokens is not None:
|
||||||
def subtask_embed_func(subtask_tokens):
|
def target_embed_func(target_tokens):
|
||||||
subtask_emb = self.paligemma_with_expert.embed_language_tokens(subtask_tokens)
|
target_emb = self.paligemma_with_expert.embed_language_tokens(target_tokens)
|
||||||
subtask_emb_dim = subtask_emb.shape[-1]
|
target_emb_dim = target_emb.shape[-1]
|
||||||
return subtask_emb * math.sqrt(subtask_emb_dim)
|
return target_emb * math.sqrt(target_emb_dim)
|
||||||
|
|
||||||
subtask_emb = self._apply_checkpoint(subtask_embed_func, subtask_tokens)
|
target_emb = self._apply_checkpoint(target_embed_func, target_tokens)
|
||||||
embs.append(subtask_emb)
|
embs.append(target_emb)
|
||||||
|
|
||||||
# Create subtask pad masks (non-zero tokens are valid)
|
# Create target pad masks (non-zero tokens are valid)
|
||||||
pad_masks.append(subtask_masks)
|
pad_masks.append(target_masks)
|
||||||
|
|
||||||
num_subtask_embs = subtask_emb.shape[1]
|
num_target_embs = target_emb.shape[1]
|
||||||
# Causal masking for subtask tokens: each subtask token can attend to images, all instruction tokens,
|
# Causal masking for target tokens: each target token can attend to images, all prompt tokens,
|
||||||
# and previous subtask tokens
|
# and previous target tokens
|
||||||
att_masks += [1] * num_subtask_embs # Use 1 for causal attention on subtask tokens
|
att_masks += [1] * num_target_embs # Use 1 for causal attention on target tokens
|
||||||
|
|
||||||
embs = torch.cat(embs, dim=1)
|
embs = torch.cat(embs, dim=1)
|
||||||
pad_masks = torch.cat(pad_masks, dim=1)
|
pad_masks = torch.cat(pad_masks, dim=1)
|
||||||
@@ -732,17 +733,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
return embs, pad_masks, att_masks, adarms_cond
|
return embs, pad_masks, att_masks, adarms_cond
|
||||||
|
|
||||||
# loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions)
|
def forward(self, images, img_masks, prompt_tokens, prompt_masks, target_tokens, target_masks, actions, noise=None, time=None) -> Tensor:
|
||||||
def forward(self, images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions, noise=None, time=None) -> Tensor:
|
|
||||||
"""Do a full training forward pass and compute the loss.
|
"""Do a full training forward pass and compute the loss.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
images: List of image tensors
|
images: List of image tensors
|
||||||
img_masks: List of image masks
|
img_masks: List of image masks
|
||||||
high_level_task: Instruction tokens WITHOUT subtask (e.g., "High level task: X; State: Y; Subtask:")
|
prompt_tokens: Prompt tokens WITHOUT target (e.g., "High level task: X; State: Y; Subtask:")
|
||||||
high_level_task_masks: Attention masks for high_level_task
|
prompt_masks: Attention masks for prompt_tokens
|
||||||
subtask_tokens: Subtask tokens to predict (e.g., tokens for "pick up the cup")
|
target_tokens: Target tokens to predict (e.g., tokens for "pick up the cup")
|
||||||
subtask_masks: Attention masks for subtask_tokens
|
target_masks: Attention masks for target_tokens
|
||||||
actions: Ground truth actions
|
actions: Ground truth actions
|
||||||
noise: Optional noise for flow matching
|
noise: Optional noise for flow matching
|
||||||
time: Optional time for flow matching
|
time: Optional time for flow matching
|
||||||
@@ -757,20 +757,19 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||||
u_t = noise - actions
|
u_t = noise - actions
|
||||||
|
|
||||||
# Embed prefix (images + high_level_task + subtask_tokens)
|
# Embed prefix (images + prompt_tokens + target_tokens)
|
||||||
# Use high_level_task (prompt WITHOUT subtask) + subtask_tokens to predict
|
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
|
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
|
||||||
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks
|
images, img_masks, prompt_tokens, target_tokens, prompt_masks, target_masks
|
||||||
)
|
)
|
||||||
|
|
||||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||||
|
|
||||||
# Prepare attention masks for prefix-only pass (for subtask token prediction)
|
# Prepare attention masks for prefix-only pass (for target token prediction)
|
||||||
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||||
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
|
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
|
||||||
|
|
||||||
# prefix-only transformer run for subtask token prediction
|
# prefix-only transformer run for target token prediction
|
||||||
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
||||||
attention_mask=att_2d_prefix_4d,
|
attention_mask=att_2d_prefix_4d,
|
||||||
position_ids=position_ids_prefix,
|
position_ids=position_ids_prefix,
|
||||||
@@ -780,37 +779,33 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
adarms_cond=[None, None],
|
adarms_cond=[None, None],
|
||||||
)
|
)
|
||||||
|
|
||||||
# LM HEAD → SUBTASK LOGITS
|
# LM HEAD → TARGET LOGITS
|
||||||
# prefix_out: (B, T_prefix, H) where T_prefix = total_T_images + T_high_level_task + T_subtask
|
# prefix_out: (B, T_prefix, H) where T_prefix = total_T_images + T_prompt + T_target
|
||||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||||
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
|
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
|
||||||
|
|
||||||
# Extract logits for subtask token prediction
|
# Extract logits for target token prediction (shifted by 1 for autoregressive training)
|
||||||
# In autoregressive modeling, output at position i predicts token at position i+1
|
# Position i predicts token i+1, so we take logits from positions before target tokens:
|
||||||
# So we take logits from one position earlier:
|
# - Position (start_index-1) (last prompt token) predicts target_tokens[0]
|
||||||
# - Position (start_index-1) (last high_level_task token) predicts subtask_tokens[0]
|
# - Position (start_index) (first target token) predicts target_tokens[1], etc.
|
||||||
# - Position (start_index) (first subtask token) predicts subtask_tokens[1], etc.
|
T_prompt = prompt_tokens.size(1)
|
||||||
T_high_level_task = high_level_task.size(1)
|
T_target = target_tokens.size(1)
|
||||||
T_subtask = subtask_tokens.size(1)
|
start_index = total_T_images + T_prompt
|
||||||
start_index = total_T_images + T_high_level_task
|
end_index = start_index + T_target
|
||||||
end_index = start_index + T_subtask
|
logits_target = logits[:, start_index-1:end_index-1, :] # (B, T_target, vocab)
|
||||||
logits_subtask = logits[:, start_index-1:end_index-1, :] # (B, T_subtask, vocab)
|
|
||||||
|
|
||||||
targets = subtask_tokens # (B, T_subtask)
|
|
||||||
# Compute cross-entropy loss
|
# Compute cross-entropy loss
|
||||||
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
||||||
# Reshape for loss computation
|
# Reshape for loss computation
|
||||||
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1)) # (B*T_subtask, vocab)
|
logits_flat = logits_target.reshape(-1, logits_target.size(-1)) # (B*T_target, vocab)
|
||||||
targets_flat = targets.reshape(-1) # (B*T_subtask)
|
targets_flat = target_tokens.reshape(-1) # (B*T_target)
|
||||||
|
|
||||||
loss_per_token = loss_fct(logits_flat, targets_flat) # (B*T_subtask)
|
loss_per_token = loss_fct(logits_flat, targets_flat) # (B*T_target)
|
||||||
loss_per_token = loss_per_token.reshape(targets.shape) # (B, T_subtask)
|
loss_per_token = loss_per_token.reshape(target_tokens.shape) # (B, T_target)
|
||||||
|
|
||||||
# Apply mask and compute mean loss over valid tokens
|
# Apply mask and compute mean loss over valid tokens
|
||||||
masked_loss = loss_per_token * subtask_masks.float()
|
masked_loss = loss_per_token * target_masks.float()
|
||||||
subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1)
|
target_loss = masked_loss.sum() / target_masks.sum().clamp(min=1)
|
||||||
|
|
||||||
breakpoint()
|
|
||||||
# Convert embeddings to bfloat16 if needed for the model
|
# Convert embeddings to bfloat16 if needed for the model
|
||||||
if (
|
if (
|
||||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
@@ -819,7 +814,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
# Concatenate prefix (images + tokens + subtask_tokens) and suffix (actions) masks
|
# Concatenate prefix (images + prompt_tokens + target_tokens) and suffix (actions) masks
|
||||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||||
|
|
||||||
@@ -856,27 +851,26 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"flow_loss": fm_loss,
|
"flow_loss": fm_loss,
|
||||||
"subtask_loss": subtask_loss,
|
"target_loss": target_loss,
|
||||||
"loss": 10 * fm_loss.mean() + subtask_loss,
|
"loss": 10 * fm_loss.mean() + target_loss,
|
||||||
}
|
}
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _generate_subtask_tokens(
|
def _generate_target_tokens(
|
||||||
self, images, img_masks, tokens, masks, tokenizer, max_length, device
|
self, images, img_masks, prompt_tokens, prompt_masks, tokenizer, max_length, device
|
||||||
):
|
):
|
||||||
"""Generate subtask tokens autoregressively using next token prediction."""
|
"""Generate target tokens autoregressively using next token prediction."""
|
||||||
bsize = tokens.shape[0]
|
bsize = prompt_tokens.shape[0]
|
||||||
|
|
||||||
# Get lm_head for token generation
|
# Get lm_head for token generation
|
||||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||||
|
|
||||||
# Embed prefix without subtask tokens first
|
# Embed prefix without target tokens first
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
|
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
|
||||||
images, img_masks, tokens, subtask_tokens=None, masks=masks, subtask_masks=None
|
images, img_masks, prompt_tokens, target_tokens=None, prompt_masks=prompt_masks, target_masks=None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize generated tokens list - start with BOS token or first token after instruction
|
# Initialize generated tokens list
|
||||||
# For PaliGemma, we'll start generation and accumulate tokens
|
|
||||||
generated_tokens = torch.zeros((bsize, max_length), dtype=torch.long, device=device)
|
generated_tokens = torch.zeros((bsize, max_length), dtype=torch.long, device=device)
|
||||||
|
|
||||||
for t in range(max_length):
|
for t in range(max_length):
|
||||||
@@ -912,7 +906,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
# Embed the generated token and append to prefix
|
# Embed the generated token and append to prefix
|
||||||
next_token_unsqueezed = next_token.unsqueeze(1) # (B, 1)
|
next_token_unsqueezed = next_token.unsqueeze(1) # (B, 1)
|
||||||
breakpoint()
|
|
||||||
|
|
||||||
def next_token_embed_func(next_token_unsqueezed):
|
def next_token_embed_func(next_token_unsqueezed):
|
||||||
next_emb = self.paligemma_with_expert.embed_language_tokens(next_token_unsqueezed)
|
next_emb = self.paligemma_with_expert.embed_language_tokens(next_token_unsqueezed)
|
||||||
@@ -938,20 +931,20 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
self,
|
self,
|
||||||
images,
|
images,
|
||||||
img_masks,
|
img_masks,
|
||||||
tokens,
|
prompt_tokens,
|
||||||
masks,
|
prompt_masks,
|
||||||
noise=None,
|
noise=None,
|
||||||
num_steps=None,
|
num_steps=None,
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
max_subtask_tokens=50,
|
max_target_tokens=50,
|
||||||
**kwargs: Unpack[ActionSelectKwargs],
|
**kwargs: Unpack[ActionSelectKwargs],
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""Do a full inference forward and compute the action."""
|
"""Do a full inference forward and compute the action."""
|
||||||
if num_steps is None:
|
if num_steps is None:
|
||||||
num_steps = self.config.num_inference_steps
|
num_steps = self.config.num_inference_steps
|
||||||
|
|
||||||
bsize = tokens.shape[0]
|
bsize = prompt_tokens.shape[0]
|
||||||
device = tokens.device
|
device = prompt_tokens.device
|
||||||
|
|
||||||
if noise is None:
|
if noise is None:
|
||||||
# Sample noise with padded dimension as expected by action_in_proj
|
# Sample noise with padded dimension as expected by action_in_proj
|
||||||
@@ -962,26 +955,28 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
) # Use config max_action_dim for internal processing
|
) # Use config max_action_dim for internal processing
|
||||||
noise = self.sample_noise(actions_shape, device)
|
noise = self.sample_noise(actions_shape, device)
|
||||||
|
|
||||||
# Generate subtask tokens autoregressively during inference
|
# Generate target tokens autoregressively during inference (if tokenizer provided)
|
||||||
generated_subtask_tokens = None
|
generated_target_tokens = None
|
||||||
|
target_masks = None
|
||||||
if tokenizer is not None:
|
if tokenizer is not None:
|
||||||
generated_subtask_tokens = self._generate_subtask_tokens(
|
generated_target_tokens = self._generate_target_tokens(
|
||||||
images, img_masks, tokens, masks, tokenizer, max_subtask_tokens, device
|
images, img_masks, prompt_tokens, prompt_masks, tokenizer, max_target_tokens, device
|
||||||
)
|
)
|
||||||
|
|
||||||
# Decode and print the generated subtask tokens
|
# Decode and print the generated target tokens
|
||||||
for i in range(bsize):
|
for i in range(bsize):
|
||||||
# Remove padding tokens (0) and special tokens
|
# Remove padding tokens (0) and special tokens
|
||||||
valid_tokens = generated_subtask_tokens[i][generated_subtask_tokens[i] != 0]
|
valid_tokens = generated_target_tokens[i][generated_target_tokens[i] != 0]
|
||||||
decoded_text = tokenizer.decode(valid_tokens, skip_special_tokens=True)
|
decoded_text = tokenizer.decode(valid_tokens, skip_special_tokens=True)
|
||||||
print(f"[Inference] Generated subtask {i}: {decoded_text}")
|
print(f"[Inference] Generated target {i}: {decoded_text}")
|
||||||
|
|
||||||
# Create mask for generated tokens (all valid)
|
# Create mask for generated tokens (all valid where token != 0)
|
||||||
subtask_masks = torch.ones_like(generated_subtask_tokens, dtype=torch.bool)
|
target_masks = generated_target_tokens != 0
|
||||||
|
|
||||||
# During inference, we don't have subtask_tokens yet, so pass None
|
# Embed prefix with prompt and optionally generated target tokens
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix(
|
prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix(
|
||||||
images, img_masks, tokens, subtask_tokens=generated_subtask_tokens, masks=masks, subtask_masks=subtask_masks
|
images, img_masks, prompt_tokens, target_tokens=generated_target_tokens,
|
||||||
|
prompt_masks=prompt_masks, target_masks=target_masks
|
||||||
)
|
)
|
||||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||||
@@ -1416,13 +1411,13 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
images, img_masks = self._preprocess_images(batch)
|
images, img_masks = self._preprocess_images(batch)
|
||||||
# Use high_level_task tokens (WITHOUT subtask) for inference - we'll generate the subtask
|
# Use prompt tokens (WITHOUT target) for inference - we'll generate the target
|
||||||
high_level_task = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS}"]
|
prompt_tokens = batch[f"{OBS_LANGUAGE_PROMPT_TOKENS}"]
|
||||||
high_level_task_masks = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK}"]
|
prompt_masks = batch[f"{OBS_LANGUAGE_PROMPT_ATTENTION_MASK}"]
|
||||||
breakpoint()
|
|
||||||
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
|
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
|
||||||
actions = self.model.sample_actions(
|
actions = self.model.sample_actions(
|
||||||
images, img_masks, high_level_task, high_level_task_masks,
|
images, img_masks, prompt_tokens, prompt_masks,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
@@ -1438,35 +1433,24 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
images, img_masks = self._preprocess_images(batch)
|
images, img_masks = self._preprocess_images(batch)
|
||||||
high_level_task = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS}"]
|
prompt_tokens = batch[f"{OBS_LANGUAGE_PROMPT_TOKENS}"]
|
||||||
high_level_task_masks = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK}"]
|
prompt_masks = batch[f"{OBS_LANGUAGE_PROMPT_ATTENTION_MASK}"]
|
||||||
subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_SUBTASK_ONLY_TOKENS}"], batch[f"{OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK}"]
|
target_tokens, target_masks = batch[f"{OBS_LANGUAGE_TARGET_TOKENS}"], batch[f"{OBS_LANGUAGE_TARGET_ATTENTION_MASK}"]
|
||||||
actions = self.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
|
|
||||||
# Decode and print ground truth subtask tokens during training
|
# Compute loss
|
||||||
if self.tokenizer is not None and self.training:
|
# prompt_tokens = instruction tokens WITHOUT target (e.g., "High level task: X; State: Y; Subtask:")
|
||||||
bsize = subtask_tokens.shape[0]
|
# target_tokens = target tokens to predict (e.g., "pick up the cup")
|
||||||
for i in range(bsize):
|
loss_dict = self.model.forward(images, img_masks, prompt_tokens, prompt_masks, target_tokens, target_masks, actions)
|
||||||
# Remove padding tokens (0) and special tokens
|
|
||||||
valid_tokens = subtask_tokens[i][subtask_masks[i].bool()]
|
|
||||||
if len(valid_tokens) > 0:
|
|
||||||
decoded_text = self.tokenizer.decode(valid_tokens, skip_special_tokens=True)
|
|
||||||
print(f"[Training] Ground truth subtask {i}: {decoded_text}")
|
|
||||||
|
|
||||||
# Compute loss (no separate state needed for PI05)
|
|
||||||
# high_level_task = instruction tokens WITHOUT subtask (e.g., "High level task: X; State: Y; Subtask:")
|
|
||||||
# subtask_tokens = subtask tokens to predict (e.g., "pick up the cup")
|
|
||||||
loss_dict = self.model.forward(images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions)
|
|
||||||
|
|
||||||
# Extract the total loss
|
# Extract the total loss
|
||||||
loss = loss_dict["loss"]
|
loss = loss_dict["loss"]
|
||||||
|
|
||||||
breakpoint()
|
|
||||||
# Prepare detailed loss dictionary for logging
|
# Prepare detailed loss dictionary for logging
|
||||||
detailed_loss_dict = {
|
detailed_loss_dict = {
|
||||||
"loss": loss.item(),
|
"loss": loss.item(),
|
||||||
"flow_loss": loss_dict["flow_loss"].mean().item(),
|
"flow_loss": loss_dict["flow_loss"].mean().item(),
|
||||||
"subtask_loss": loss_dict["subtask_loss"].item(),
|
"target_loss": loss_dict["target_loss"].item(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return loss, detailed_loss_dict
|
return loss, detailed_loss_dict
|
||||||
|
|||||||
@@ -54,8 +54,8 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
|
|||||||
|
|
||||||
max_state_dim: int = 32
|
max_state_dim: int = 32
|
||||||
task_key: str = "task"
|
task_key: str = "task"
|
||||||
high_level_task_key: str = "user_prompt"
|
prompt_key: str = "prompt"
|
||||||
subtask_only_key: str = "subtask"
|
target_key: str = "target"
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
transition = transition.copy()
|
transition = transition.copy()
|
||||||
@@ -67,7 +67,7 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
|
|||||||
if tasks is None:
|
if tasks is None:
|
||||||
raise ValueError("No task found in complementary data")
|
raise ValueError("No task found in complementary data")
|
||||||
|
|
||||||
high_level_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.high_level_task_key)
|
high_level_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get("user_prompt")
|
||||||
|
|
||||||
# TODO: check if this necessary
|
# TODO: check if this necessary
|
||||||
state = deepcopy(state)
|
state = deepcopy(state)
|
||||||
@@ -86,36 +86,27 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
|
|||||||
for high_level_task in high_level_tasks:
|
for high_level_task in high_level_tasks:
|
||||||
cleaned_high_level_tasks.append(high_level_task.strip().replace("_", " ").replace("\n", " "))
|
cleaned_high_level_tasks.append(high_level_task.strip().replace("_", " ").replace("\n", " "))
|
||||||
|
|
||||||
# Process low level tasks with state information
|
# Process tasks to create prompts (input) and targets (what to predict)
|
||||||
low_level_prompts = []
|
prompts = [] # Input prompts ending with "Subtask:"
|
||||||
subtask_only_prompts = [] # Store only the subtask text for prediction
|
targets = [] # Target text to predict (the subtask)
|
||||||
for i, task in enumerate(tasks):
|
for i, task in enumerate(tasks):
|
||||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||||
state_str = " ".join(map(str, discretized_states[i]))
|
state_str = " ".join(map(str, discretized_states[i]))
|
||||||
|
|
||||||
# Store only the subtask text (used as prediction target)
|
# Store the subtask text as target for prediction
|
||||||
subtask_only_prompts.append(cleaned_text)
|
targets.append(cleaned_text)
|
||||||
|
|
||||||
if cleaned_high_level_tasks:
|
if cleaned_high_level_tasks:
|
||||||
cleaned_high_level_task = cleaned_high_level_tasks[i]
|
cleaned_high_level_task = cleaned_high_level_tasks[i]
|
||||||
full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask: {cleaned_text}"
|
# Prompt ends with "Subtask:" - model will predict the target
|
||||||
|
prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask:"
|
||||||
else:
|
else:
|
||||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
raise ValueError("No high level tasks found")
|
||||||
|
|
||||||
low_level_prompts.append(full_prompt)
|
prompts.append(prompt)
|
||||||
|
|
||||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = low_level_prompts
|
transition[TransitionKey.COMPLEMENTARY_DATA][self.prompt_key] = prompts
|
||||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.subtask_only_key] = subtask_only_prompts
|
transition[TransitionKey.COMPLEMENTARY_DATA][self.target_key] = targets
|
||||||
|
|
||||||
# Process high level tasks without state information (if available)
|
|
||||||
if high_level_tasks is not None:
|
|
||||||
high_level_prompts = []
|
|
||||||
for i, cleaned_high_level_task in enumerate(cleaned_high_level_tasks):
|
|
||||||
state_str = " ".join(map(str, discretized_states[i]))
|
|
||||||
full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask:"
|
|
||||||
high_level_prompts.append(full_prompt)
|
|
||||||
|
|
||||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.high_level_task_key] = high_level_prompts
|
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
def transform_features(
|
def transform_features(
|
||||||
|
|||||||
@@ -31,11 +31,11 @@ import torch
|
|||||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
OBS_LANGUAGE_ATTENTION_MASK,
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK,
|
OBS_LANGUAGE_PROMPT_ATTENTION_MASK,
|
||||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS,
|
OBS_LANGUAGE_PROMPT_TOKENS,
|
||||||
OBS_LANGUAGE_TOKENS,
|
OBS_LANGUAGE_TOKENS,
|
||||||
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS,
|
OBS_LANGUAGE_TARGET_TOKENS,
|
||||||
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK,
|
OBS_LANGUAGE_TARGET_ATTENTION_MASK,
|
||||||
)
|
)
|
||||||
from lerobot.utils.import_utils import _transformers_available
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
@@ -59,8 +59,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting
|
tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting
|
||||||
token IDs and attention mask to the `observation` dictionary.
|
token IDs and attention mask to the `observation` dictionary.
|
||||||
|
|
||||||
Optionally, this step can also tokenize a high-level task (e.g., user prompt) and/or
|
Optionally, this step can also tokenize a prompt (input for generation) and/or
|
||||||
a subtask if present in the complementary data, creating separate tokenized observations.
|
a target (text to predict) if present in the complementary data, creating separate tokenized observations.
|
||||||
|
|
||||||
Requires the `transformers` library to be installed.
|
Requires the `transformers` library to be installed.
|
||||||
|
|
||||||
@@ -69,8 +69,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored.
|
tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||||
max_length: The maximum length to pad or truncate sequences to.
|
max_length: The maximum length to pad or truncate sequences to.
|
||||||
task_key: The key in `complementary_data` where the task string is stored.
|
task_key: The key in `complementary_data` where the task string is stored.
|
||||||
high_level_task_key: The key in `complementary_data` where the high-level task (user prompt) is stored.
|
prompt_key: The key in `complementary_data` where the prompt (input for generation) is stored.
|
||||||
subtask_key: The key in `complementary_data` where the subtask string is stored.
|
target_key: The key in `complementary_data` where the target (text to predict) is stored.
|
||||||
padding_side: The side to pad on ('left' or 'right').
|
padding_side: The side to pad on ('left' or 'right').
|
||||||
padding: The padding strategy ('max_length', 'longest', etc.).
|
padding: The padding strategy ('max_length', 'longest', etc.).
|
||||||
truncation: Whether to truncate sequences longer than `max_length`.
|
truncation: Whether to truncate sequences longer than `max_length`.
|
||||||
@@ -81,8 +81,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency
|
tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency
|
||||||
max_length: int = 512
|
max_length: int = 512
|
||||||
task_key: str = "task"
|
task_key: str = "task"
|
||||||
high_level_task_key: str = "user_prompt"
|
prompt_key: str = "prompt"
|
||||||
subtask_key: str = "subtask"
|
target_key: str = "target"
|
||||||
padding_side: str = "right"
|
padding_side: str = "right"
|
||||||
padding: str = "max_length"
|
padding: str = "max_length"
|
||||||
truncation: bool = True
|
truncation: bool = True
|
||||||
@@ -147,57 +147,57 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_high_level_task(self, transition: EnvTransition) -> list[str] | None:
|
def get_prompt(self, transition: EnvTransition) -> list[str] | None:
|
||||||
"""
|
"""
|
||||||
Extracts the high-level task description(s) from the transition's complementary data.
|
Extracts the prompt (input for generation) from the transition's complementary data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
transition: The environment transition.
|
transition: The environment transition.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of high-level task strings, or None if the high-level task key is not found or the value is None.
|
A list of prompt strings, or None if the prompt key is not found or the value is None.
|
||||||
"""
|
"""
|
||||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||||
if complementary_data is None:
|
if complementary_data is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
high_level_task = complementary_data.get(self.high_level_task_key)
|
prompt = complementary_data.get(self.prompt_key)
|
||||||
|
|
||||||
if high_level_task is None:
|
if prompt is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Standardize to a list of strings for the tokenizer
|
# Standardize to a list of strings for the tokenizer
|
||||||
if isinstance(high_level_task, str):
|
if isinstance(prompt, str):
|
||||||
return [high_level_task]
|
return [prompt]
|
||||||
elif isinstance(high_level_task, list) and all(isinstance(t, str) for t in high_level_task):
|
elif isinstance(prompt, list) and all(isinstance(t, str) for t in prompt):
|
||||||
return high_level_task
|
return prompt
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
|
def get_target(self, transition: EnvTransition) -> list[str] | None:
|
||||||
"""
|
"""
|
||||||
Extracts the subtask description(s) from the transition's complementary data.
|
Extracts the target (text to predict) from the transition's complementary data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
transition: The environment transition.
|
transition: The environment transition.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of subtask strings, or None if the subtask key is not found or the value is None.
|
A list of target strings, or None if the target key is not found or the value is None.
|
||||||
"""
|
"""
|
||||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||||
if complementary_data is None:
|
if complementary_data is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
subtask = complementary_data.get(self.subtask_key)
|
target = complementary_data.get(self.target_key)
|
||||||
|
|
||||||
if subtask is None:
|
if target is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Standardize to a list of strings for the tokenizer
|
# Standardize to a list of strings for the tokenizer
|
||||||
if isinstance(subtask, str):
|
if isinstance(target, str):
|
||||||
return [subtask]
|
return [target]
|
||||||
elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask):
|
elif isinstance(target, list) and all(isinstance(t, str) for t in target):
|
||||||
return subtask
|
return target
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -238,39 +238,37 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||||
|
|
||||||
# Also tokenize high-level task if available
|
# Also tokenize prompt (input for generation) if available
|
||||||
high_level_task = self.get_high_level_task(self.transition)
|
prompt = self.get_prompt(self.transition)
|
||||||
if high_level_task is not None:
|
if prompt is not None:
|
||||||
# Tokenize the high-level task
|
tokenized_prompt_input = self._tokenize_text(prompt)
|
||||||
tokenized_high_level_prompt = self._tokenize_text(high_level_task)
|
|
||||||
|
|
||||||
# Move to the same device
|
# Move to the same device
|
||||||
if target_device is not None:
|
if target_device is not None:
|
||||||
tokenized_high_level_prompt = {
|
tokenized_prompt_input = {
|
||||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||||
for k, v in tokenized_high_level_prompt.items()
|
for k, v in tokenized_prompt_input.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add high-level tokenized data to the observation
|
# Add prompt tokenized data to the observation
|
||||||
new_observation[OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS] = tokenized_high_level_prompt["input_ids"]
|
new_observation[OBS_LANGUAGE_PROMPT_TOKENS] = tokenized_prompt_input["input_ids"]
|
||||||
new_observation[OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK] = tokenized_high_level_prompt["attention_mask"].to(dtype=torch.bool)
|
new_observation[OBS_LANGUAGE_PROMPT_ATTENTION_MASK] = tokenized_prompt_input["attention_mask"].to(dtype=torch.bool)
|
||||||
|
|
||||||
# Also tokenize subtask if available
|
# Also tokenize target (text to predict) if available
|
||||||
subtask = self.get_subtask(self.transition)
|
target = self.get_target(self.transition)
|
||||||
if subtask is not None:
|
if target is not None:
|
||||||
# Tokenize the subtask
|
tokenized_target = self._tokenize_text(target)
|
||||||
tokenized_subtask_prompt = self._tokenize_text(subtask)
|
|
||||||
|
|
||||||
# Move to the same device
|
# Move to the same device
|
||||||
if target_device is not None:
|
if target_device is not None:
|
||||||
tokenized_subtask_prompt = {
|
tokenized_target = {
|
||||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||||
for k, v in tokenized_subtask_prompt.items()
|
for k, v in tokenized_target.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add subtask tokenized data to the observation
|
# Add target tokenized data to the observation
|
||||||
new_observation[OBS_LANGUAGE_SUBTASK_ONLY_TOKENS] = tokenized_subtask_prompt["input_ids"]
|
new_observation[OBS_LANGUAGE_TARGET_TOKENS] = tokenized_target["input_ids"]
|
||||||
new_observation[OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK] = tokenized_subtask_prompt["attention_mask"].to(dtype=torch.bool)
|
new_observation[OBS_LANGUAGE_TARGET_ATTENTION_MASK] = tokenized_target["attention_mask"].to(dtype=torch.bool)
|
||||||
|
|
||||||
return new_observation
|
return new_observation
|
||||||
|
|
||||||
@@ -332,7 +330,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
config = {
|
config = {
|
||||||
"max_length": self.max_length,
|
"max_length": self.max_length,
|
||||||
"task_key": self.task_key,
|
"task_key": self.task_key,
|
||||||
"high_level_task_key": self.high_level_task_key,
|
"prompt_key": self.prompt_key,
|
||||||
|
"target_key": self.target_key,
|
||||||
"padding_side": self.padding_side,
|
"padding_side": self.padding_side,
|
||||||
"padding": self.padding,
|
"padding": self.padding,
|
||||||
"truncation": self.truncation,
|
"truncation": self.truncation,
|
||||||
@@ -371,24 +370,25 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add features for high-level task tokens and attention mask if they don't already exist
|
# Add features for prompt tokens and attention mask if they don't already exist
|
||||||
if OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
if OBS_LANGUAGE_PROMPT_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS] = PolicyFeature(
|
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_PROMPT_TOKENS] = PolicyFeature(
|
||||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||||
)
|
)
|
||||||
|
|
||||||
if OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
if OBS_LANGUAGE_PROMPT_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK] = PolicyFeature(
|
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_PROMPT_ATTENTION_MASK] = PolicyFeature(
|
||||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||||
)
|
)
|
||||||
|
|
||||||
if OBS_LANGUAGE_SUBTASK_ONLY_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
# Add features for target tokens and attention mask if they don't already exist
|
||||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ONLY_TOKENS] = PolicyFeature(
|
if OBS_LANGUAGE_TARGET_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||||
|
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TARGET_TOKENS] = PolicyFeature(
|
||||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||||
)
|
)
|
||||||
|
|
||||||
if OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
if OBS_LANGUAGE_TARGET_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK] = PolicyFeature(
|
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TARGET_ATTENTION_MASK] = PolicyFeature(
|
||||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -26,12 +26,12 @@ OBS_IMAGES = OBS_IMAGE + "s"
|
|||||||
OBS_LANGUAGE = OBS_STR + ".language"
|
OBS_LANGUAGE = OBS_STR + ".language"
|
||||||
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
||||||
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||||
OBS_LANGUAGE_HIGH_LEVEL_TASK = OBS_STR + ".user_prompt"
|
OBS_LANGUAGE_PROMPT = OBS_STR + ".prompt"
|
||||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS = OBS_LANGUAGE_HIGH_LEVEL_TASK + ".tokens"
|
OBS_LANGUAGE_PROMPT_TOKENS = OBS_LANGUAGE_PROMPT + ".tokens"
|
||||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK = OBS_LANGUAGE_HIGH_LEVEL_TASK + ".attention_mask"
|
OBS_LANGUAGE_PROMPT_ATTENTION_MASK = OBS_LANGUAGE_PROMPT + ".attention_mask"
|
||||||
OBS_LANGUAGE_SUBTASK_ONLY = OBS_STR + ".subtask"
|
OBS_LANGUAGE_TARGET = OBS_STR + ".target"
|
||||||
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS = OBS_LANGUAGE_SUBTASK_ONLY + ".tokens"
|
OBS_LANGUAGE_TARGET_TOKENS = OBS_LANGUAGE_TARGET + ".tokens"
|
||||||
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK_ONLY + ".attention_mask"
|
OBS_LANGUAGE_TARGET_ATTENTION_MASK = OBS_LANGUAGE_TARGET + ".attention_mask"
|
||||||
ACTION = "action"
|
ACTION = "action"
|
||||||
REWARD = "next.reward"
|
REWARD = "next.reward"
|
||||||
TRUNCATED = "next.truncated"
|
TRUNCATED = "next.truncated"
|
||||||
|
|||||||
Reference in New Issue
Block a user