mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
add training
This commit is contained in:
@@ -60,8 +60,8 @@ class PI05Config(PreTrainedConfig):
|
|||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"VISUAL": NormalizationMode.IDENTITY,
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
"STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state
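|
"STATE": NormalizationMode.MEAN_STD, # Pi0.5 originally uses quantiles for state; switched to mean/std here. NOTE(review): the state tokenizer step discretizes assuming a [-1, 1] range, which mean/std does not guarantee — confirm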
|
||||||
"ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action
|
"ACTION": NormalizationMode.MEAN_STD, # Pi0.5 originally uses quantiles for action; switched to mean/std here
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -48,6 +48,9 @@ from lerobot.utils.constants import (
|
|||||||
ACTION,
|
ACTION,
|
||||||
OBS_LANGUAGE_ATTENTION_MASK,
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
OBS_LANGUAGE_TOKENS,
|
OBS_LANGUAGE_TOKENS,
|
||||||
|
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS,
|
||||||
|
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS,
|
||||||
|
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK,
|
||||||
OPENPI_ATTENTION_MASK_VALUE,
|
OPENPI_ATTENTION_MASK_VALUE,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -429,6 +432,8 @@ class PaliGemmaWithExpertModel(
|
|||||||
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
|
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
|
||||||
)
|
)
|
||||||
prefix_past_key_values = prefix_output.past_key_values
|
prefix_past_key_values = prefix_output.past_key_values
|
||||||
|
# prefix_output to be used for the language head
|
||||||
|
# shape: [batch_size, seq_len, hidden_size] with hidden_size = 2048
|
||||||
prefix_output = prefix_output.last_hidden_state
|
prefix_output = prefix_output.last_hidden_state
|
||||||
suffix_output = None
|
suffix_output = None
|
||||||
elif inputs_embeds[0] is None:
|
elif inputs_embeds[0] is None:
|
||||||
@@ -578,10 +583,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
)
|
)
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
def _prepare_attention_masks_4d(self, att_2d_masks):
|
def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None):
|
||||||
"""Helper method to prepare 4D attention masks for transformer."""
|
"""Helper method to prepare 4D attention masks for transformer."""
|
||||||
att_2d_masks_4d = att_2d_masks[:, None, :, :]
|
att_2d_masks_4d = att_2d_masks[:, None, :, :]
|
||||||
return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
||||||
|
if dtype is not None:
|
||||||
|
result = result.to(dtype=dtype)
|
||||||
|
return result
|
||||||
|
|
||||||
def sample_noise(self, shape, device):
|
def sample_noise(self, shape, device):
|
||||||
return torch.normal(
|
return torch.normal(
|
||||||
@@ -600,13 +608,28 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
return time.to(dtype=torch.float32, device=device)
|
return time.to(dtype=torch.float32, device=device)
|
||||||
|
|
||||||
def embed_prefix(
|
def embed_prefix(
|
||||||
self, images, img_masks, tokens, masks
|
self, images, img_masks, tokens, subtask_tokens, masks, subtask_masks
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
|
||||||
"""Embed images with SigLIP and language tokens with embedding layer."""
|
"""Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: List of image tensors
|
||||||
|
img_masks: List of image masks
|
||||||
|
tokens: Language instruction tokens
|
||||||
|
subtask_tokens: Subtask tokens to predict (can be None for inference)
|
||||||
|
masks: Attention masks for tokens
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided)]
|
||||||
|
pad_masks: Padding masks
|
||||||
|
att_masks: Attention masks (with causal masking for subtask prediction if subtask_tokens provided)
|
||||||
|
total_T_images: Total number of image tokens
|
||||||
|
"""
|
||||||
embs = []
|
embs = []
|
||||||
pad_masks = []
|
pad_masks = []
|
||||||
att_masks = []
|
att_masks = []
|
||||||
|
total_T_images = 0
|
||||||
|
|
||||||
# Process images
|
# Process images
|
||||||
for img, img_mask in zip(images, img_masks, strict=True):
|
for img, img_mask in zip(images, img_masks, strict=True):
|
||||||
|
|
||||||
@@ -618,9 +641,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
embs.append(img_emb)
|
embs.append(img_emb)
|
||||||
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
||||||
att_masks += [0] * num_img_embs
|
att_masks += [0] * num_img_embs # Images can attend to all previous tokens
|
||||||
|
total_T_images += num_img_embs
|
||||||
# Process language tokens
|
|
||||||
|
# Process language instruction tokens
|
||||||
def lang_embed_func(tokens):
|
def lang_embed_func(tokens):
|
||||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||||
lang_emb_dim = lang_emb.shape[-1]
|
lang_emb_dim = lang_emb.shape[-1]
|
||||||
@@ -631,16 +655,34 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
pad_masks.append(masks)
|
pad_masks.append(masks)
|
||||||
|
|
||||||
num_lang_embs = lang_emb.shape[1]
|
num_lang_embs = lang_emb.shape[1]
|
||||||
att_masks += [0] * num_lang_embs
|
att_masks += [0] * num_lang_embs # Language tokens can attend to all previous tokens (images + tokens)
|
||||||
|
|
||||||
|
# Process subtask tokens if provided (these are predicted, so use causal masking)
|
||||||
|
if subtask_tokens is not None:
|
||||||
|
def subtask_embed_func(subtask_tokens):
|
||||||
|
subtask_emb = self.paligemma_with_expert.embed_language_tokens(subtask_tokens)
|
||||||
|
subtask_emb_dim = subtask_emb.shape[-1]
|
||||||
|
return subtask_emb * math.sqrt(subtask_emb_dim)
|
||||||
|
|
||||||
|
subtask_emb = self._apply_checkpoint(subtask_embed_func, subtask_tokens)
|
||||||
|
embs.append(subtask_emb)
|
||||||
|
|
||||||
|
# Use the caller-supplied subtask attention mask as the pad mask for the subtask tokens
|
||||||
|
pad_masks.append(subtask_masks)
|
||||||
|
|
||||||
|
num_subtask_embs = subtask_emb.shape[1]
|
||||||
|
# Causal masking for subtask tokens: each subtask token can attend to images, all instruction tokens,
|
||||||
|
# and previous subtask tokens
|
||||||
|
att_masks += [1] * num_subtask_embs # Use 1 for causal attention on subtask tokens
|
||||||
|
|
||||||
embs = torch.cat(embs, dim=1)
|
embs = torch.cat(embs, dim=1)
|
||||||
pad_masks = torch.cat(pad_masks, dim=1)
|
pad_masks = torch.cat(pad_masks, dim=1)
|
||||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||||
|
|
||||||
bsize = pad_masks.shape[0]
|
bsize = pad_masks.shape[0]
|
||||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
att_masks = att_masks[None, :].expand(bsize, att_masks.shape[0])
|
||||||
|
|
||||||
return embs, pad_masks, att_masks
|
return embs, pad_masks, att_masks, total_T_images
|
||||||
|
|
||||||
def embed_suffix(self, noisy_actions, timestep):
|
def embed_suffix(self, noisy_actions, timestep):
|
||||||
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
|
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||||
@@ -689,7 +731,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
return embs, pad_masks, att_masks, adarms_cond
|
return embs, pad_masks, att_masks, adarms_cond
|
||||||
|
|
||||||
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
|
# loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions)
|
||||||
|
def forward(self, images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions, noise=None, time=None) -> Tensor:
|
||||||
"""Do a full training forward pass and compute the loss."""
|
"""Do a full training forward pass and compute the loss."""
|
||||||
if noise is None:
|
if noise is None:
|
||||||
noise = self.sample_noise(actions.shape, actions.device)
|
noise = self.sample_noise(actions.shape, actions.device)
|
||||||
@@ -701,9 +744,55 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||||
u_t = noise - actions
|
u_t = noise - actions
|
||||||
|
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
# Embed prefix (images + tokens + subtask_tokens)
|
||||||
|
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
|
||||||
|
images, img_masks, tokens, subtask_tokens, masks, subtask_masks
|
||||||
|
)
|
||||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||||
|
|
||||||
|
# Prepare attention masks for prefix-only pass (for subtask token prediction)
|
||||||
|
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||||
|
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||||
|
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
|
||||||
|
|
||||||
|
# prefix-only transformer run for subtask token prediction
|
||||||
|
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
||||||
|
attention_mask=att_2d_prefix_4d,
|
||||||
|
position_ids=position_ids_prefix,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=[prefix_embs, None], # SUFFIX = None
|
||||||
|
use_cache=False,
|
||||||
|
adarms_cond=[None, None],
|
||||||
|
)
|
||||||
|
|
||||||
|
# LM HEAD → SUBTASK LOGITS
|
||||||
|
# prefix_out: (B, T_prefix, H) where T_prefix = total_T_images + T_tokens + T_subtask
|
||||||
|
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||||
|
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
|
||||||
|
|
||||||
|
# Extract logits for subtask token prediction
|
||||||
|
# Subtask tokens start after images and instruction tokens
|
||||||
|
T_tokens = tokens.size(1)
|
||||||
|
T_subtask = subtask_tokens.size(1)
|
||||||
|
start_index = total_T_images + T_tokens
|
||||||
|
end_index = start_index + T_subtask
|
||||||
|
logits_subtask = logits[:, start_index:end_index, :] # (B, T_subtask, vocab)
|
||||||
|
|
||||||
|
targets = subtask_tokens # (B, T_subtask)
|
||||||
|
# Compute cross-entropy loss
|
||||||
|
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
||||||
|
# Reshape for loss computation
|
||||||
|
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1)) # (B*T_subtask, vocab)
|
||||||
|
targets_flat = targets.reshape(-1) # (B*T_subtask)
|
||||||
|
|
||||||
|
loss_per_token = loss_fct(logits_flat, targets_flat) # (B*T_subtask)
|
||||||
|
loss_per_token = loss_per_token.reshape(targets.shape) # (B, T_subtask)
|
||||||
|
|
||||||
|
# Apply mask and compute mean loss over valid tokens
|
||||||
|
masked_loss = loss_per_token * subtask_masks.float()
|
||||||
|
subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1)
|
||||||
|
|
||||||
|
# Convert embeddings to bfloat16 if needed for the model
|
||||||
if (
|
if (
|
||||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||||
== torch.bfloat16
|
== torch.bfloat16
|
||||||
@@ -711,13 +800,14 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Concatenate prefix (images + tokens + subtask_tokens) and suffix (actions) masks
|
||||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||||
|
|
||||||
|
# Prepare attention masks for full forward pass (prefix + suffix)
|
||||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||||
|
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
|
||||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
|
|
||||||
|
|
||||||
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
||||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||||
@@ -728,6 +818,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
use_cache=False,
|
use_cache=False,
|
||||||
adarms_cond=[None, adarms_cond],
|
adarms_cond=[None, adarms_cond],
|
||||||
)
|
)
|
||||||
|
# prefix_out to be used for the language head
|
||||||
return suffix_out
|
return suffix_out
|
||||||
|
|
||||||
suffix_out = self._apply_checkpoint(
|
suffix_out = self._apply_checkpoint(
|
||||||
@@ -742,7 +833,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
|
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
|
||||||
|
|
||||||
return F.mse_loss(u_t, v_t, reduction="none")
|
fm_loss = F.mse_loss(u_t, v_t, reduction="none")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"flow_loss": fm_loss,
|
||||||
|
"subtask_loss": subtask_loss,
|
||||||
|
"loss": 10 * fm_loss.mean() + subtask_loss,
|
||||||
|
}
|
||||||
|
|
||||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||||
def sample_actions(
|
def sample_actions(
|
||||||
@@ -771,11 +868,14 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
) # Use config max_action_dim for internal processing
|
) # Use config max_action_dim for internal processing
|
||||||
noise = self.sample_noise(actions_shape, device)
|
noise = self.sample_noise(actions_shape, device)
|
||||||
|
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
# During inference, we don't need subtask_tokens, so pass None
|
||||||
|
prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix(
|
||||||
|
images, img_masks, tokens, subtask_tokens=None, masks=masks
|
||||||
|
)
|
||||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||||
|
|
||||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks, dtype=prefix_embs.dtype)
|
||||||
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||||
|
|
||||||
_, past_key_values = self.paligemma_with_expert.forward(
|
_, past_key_values = self.paligemma_with_expert.forward(
|
||||||
@@ -852,7 +952,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||||
|
|
||||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks, dtype=suffix_embs.dtype)
|
||||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||||
|
|
||||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||||
@@ -1198,7 +1298,7 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
images, img_masks = self._preprocess_images(batch)
|
images, img_masks = self._preprocess_images(batch)
|
||||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||||
|
|
||||||
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
|
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
|
||||||
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
|
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
|
||||||
|
|
||||||
@@ -1214,21 +1314,22 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
images, img_masks = self._preprocess_images(batch)
|
images, img_masks = self._preprocess_images(batch)
|
||||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||||
|
subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_SUBTASK_ONLY_TOKENS}"], batch[f"{OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK}"]
|
||||||
|
high_level_task = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS}"]
|
||||||
actions = self.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
|
|
||||||
# Compute loss (no separate state needed for PI05)
|
# Compute loss (no separate state needed for PI05)
|
||||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
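# tokens = prompt tokens fed to the prefix, subtask_tokens = subtask tokens to predict. NOTE(review): high_level_task does not appear to be used in the visible parts of model.forward — confirm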
|
||||||
|
loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions)
|
||||||
|
|
||||||
# Truncate losses to actual action dimensions
|
# Extract the total loss
|
||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
loss = loss_dict["loss"]
|
||||||
losses = losses[:, :, :original_action_dim]
|
|
||||||
|
# Prepare detailed loss dictionary for logging
|
||||||
loss = losses.mean()
|
detailed_loss_dict = {
|
||||||
|
|
||||||
loss_dict = {
|
|
||||||
"loss": loss.item(),
|
"loss": loss.item(),
|
||||||
"loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
|
"flow_loss": loss_dict["flow_loss"].mean().item(),
|
||||||
|
"subtask_loss": loss_dict["subtask_loss"].item(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return loss, loss_dict
|
return loss, detailed_loss_dict
|
||||||
|
|||||||
@@ -47,13 +47,15 @@ from lerobot.utils.constants import (
|
|||||||
|
|
||||||
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
|
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
|
||||||
@dataclass
|
@dataclass
|
||||||
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
|
||||||
"""
|
"""
|
||||||
Processor step to prepare the state and tokenize the language input.
|
Processor step to prepare the state and tokenize the language input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
max_state_dim: int = 32
|
max_state_dim: int = 32
|
||||||
task_key: str = "task"
|
task_key: str = "task"
|
||||||
|
high_level_task_key: str = "user_prompt"
|
||||||
|
subtask_only_key: str = "subtask"
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
transition = transition.copy()
|
transition = transition.copy()
|
||||||
@@ -64,6 +66,8 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
|||||||
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
|
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
|
||||||
if tasks is None:
|
if tasks is None:
|
||||||
raise ValueError("No task found in complementary data")
|
raise ValueError("No task found in complementary data")
|
||||||
|
|
||||||
|
high_level_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.high_level_task_key)
|
||||||
|
|
||||||
# TODO: check if this is necessary
|
# TODO: check if this is necessary
|
||||||
state = deepcopy(state)
|
state = deepcopy(state)
|
||||||
@@ -76,16 +80,42 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
|||||||
state_np = state.cpu().numpy()
|
state_np = state.cpu().numpy()
|
||||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||||
|
|
||||||
full_prompts = []
|
# Clean high level tasks first (if available)
|
||||||
|
cleaned_high_level_tasks = []
|
||||||
|
if high_level_tasks is not None:
|
||||||
|
for high_level_task in high_level_tasks:
|
||||||
|
cleaned_high_level_tasks.append(high_level_task.strip().replace("_", " ").replace("\n", " "))
|
||||||
|
|
||||||
|
# Process low level tasks with state information
|
||||||
|
low_level_prompts = []
|
||||||
|
subtask_only_prompts = [] # Store only the subtask text for prediction
|
||||||
for i, task in enumerate(tasks):
|
for i, task in enumerate(tasks):
|
||||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||||
state_str = " ".join(map(str, discretized_states[i]))
|
state_str = " ".join(map(str, discretized_states[i]))
|
||||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
|
||||||
full_prompts.append(full_prompt)
|
# Store only the subtask text (used as prediction target)
|
||||||
|
subtask_only_prompts.append(cleaned_text)
|
||||||
|
|
||||||
|
if cleaned_high_level_tasks:
|
||||||
|
cleaned_high_level_task = cleaned_high_level_tasks[i]
|
||||||
|
full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask: {cleaned_text}"
|
||||||
|
else:
|
||||||
|
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||||
|
|
||||||
|
low_level_prompts.append(full_prompt)
|
||||||
|
|
||||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
|
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = low_level_prompts
|
||||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
|
transition[TransitionKey.COMPLEMENTARY_DATA][self.subtask_only_key] = subtask_only_prompts
|
||||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
|
||||||
|
# Process high level tasks without state information (if available)
|
||||||
|
if high_level_tasks is not None:
|
||||||
|
high_level_prompts = []
|
||||||
|
for i, cleaned_high_level_task in enumerate(cleaned_high_level_tasks):
|
||||||
|
state_str = " ".join(map(str, discretized_states[i]))
|
||||||
|
full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask:"
|
||||||
|
high_level_prompts.append(full_prompt)
|
||||||
|
|
||||||
|
transition[TransitionKey.COMPLEMENTARY_DATA][self.high_level_task_key] = high_level_prompts
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
def transform_features(
|
def transform_features(
|
||||||
@@ -133,14 +163,14 @@ def make_pi05_pre_post_processors(
|
|||||||
input_steps: list[ProcessorStep] = [
|
input_steps: list[ProcessorStep] = [
|
||||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||||
AddBatchDimensionProcessorStep(),
|
AddBatchDimensionProcessorStep(),
|
||||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
|
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateAndLanguageTokenizerProcessorStep
|
||||||
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
||||||
NormalizerProcessorStep(
|
NormalizerProcessorStep(
|
||||||
features={**config.input_features, **config.output_features},
|
features={**config.input_features, **config.output_features},
|
||||||
norm_map=config.normalization_mapping,
|
norm_map=config.normalization_mapping,
|
||||||
stats=dataset_stats,
|
stats=dataset_stats,
|
||||||
),
|
),
|
||||||
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
Pi05PrepareStateAndLanguageTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||||
TokenizerProcessorStep(
|
TokenizerProcessorStep(
|
||||||
tokenizer_name="google/paligemma-3b-pt-224",
|
tokenizer_name="google/paligemma-3b-pt-224",
|
||||||
max_length=config.tokenizer_max_length,
|
max_length=config.tokenizer_max_length,
|
||||||
|
|||||||
@@ -168,10 +168,12 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
|||||||
"""
|
"""
|
||||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||||
|
user_prompt_key = {"user_prompt": batch["user_prompt"]} if "user_prompt" in batch else {}
|
||||||
|
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||||
|
|
||||||
return {**pad_keys, **task_key, **index_key, **task_index_key}
|
return {**pad_keys, **task_key, **index_key, **task_index_key, **user_prompt_key, **subtask_key}
|
||||||
|
|
||||||
|
|
||||||
def create_transition(
|
def create_transition(
|
||||||
|
|||||||
@@ -47,7 +47,6 @@ class RenameObservationsProcessorStep(ObservationProcessorStep):
|
|||||||
processed_obs[self.rename_map[key]] = value
|
processed_obs[self.rename_map[key]] = value
|
||||||
else:
|
else:
|
||||||
processed_obs[key] = value
|
processed_obs[key] = value
|
||||||
|
|
||||||
return processed_obs
|
return processed_obs
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
|||||||
@@ -29,7 +29,14 @@ from typing import TYPE_CHECKING, Any
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
from lerobot.utils.constants import (
|
||||||
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
|
OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK,
|
||||||
|
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS,
|
||||||
|
OBS_LANGUAGE_TOKENS,
|
||||||
|
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS,
|
||||||
|
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK,
|
||||||
|
)
|
||||||
from lerobot.utils.import_utils import _transformers_available
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
from .core import EnvTransition, TransitionKey
|
from .core import EnvTransition, TransitionKey
|
||||||
@@ -52,6 +59,9 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting
|
tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting
|
||||||
token IDs and attention mask to the `observation` dictionary.
|
token IDs and attention mask to the `observation` dictionary.
|
||||||
|
|
||||||
|
Optionally, this step can also tokenize a high-level task (e.g., user prompt) and/or
|
||||||
|
a subtask if present in the complementary data, creating separate tokenized observations.
|
||||||
|
|
||||||
Requires the `transformers` library to be installed.
|
Requires the `transformers` library to be installed.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
@@ -59,6 +69,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored.
|
tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||||
max_length: The maximum length to pad or truncate sequences to.
|
max_length: The maximum length to pad or truncate sequences to.
|
||||||
task_key: The key in `complementary_data` where the task string is stored.
|
task_key: The key in `complementary_data` where the task string is stored.
|
||||||
|
high_level_task_key: The key in `complementary_data` where the high-level task (user prompt) is stored.
|
||||||
|
subtask_key: The key in `complementary_data` where the subtask string is stored.
|
||||||
padding_side: The side to pad on ('left' or 'right').
|
padding_side: The side to pad on ('left' or 'right').
|
||||||
padding: The padding strategy ('max_length', 'longest', etc.).
|
padding: The padding strategy ('max_length', 'longest', etc.).
|
||||||
truncation: Whether to truncate sequences longer than `max_length`.
|
truncation: Whether to truncate sequences longer than `max_length`.
|
||||||
@@ -69,6 +81,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency
|
tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency
|
||||||
max_length: int = 512
|
max_length: int = 512
|
||||||
task_key: str = "task"
|
task_key: str = "task"
|
||||||
|
high_level_task_key: str = "user_prompt"
|
||||||
|
subtask_key: str = "subtask"
|
||||||
padding_side: str = "right"
|
padding_side: str = "right"
|
||||||
padding: str = "max_length"
|
padding: str = "max_length"
|
||||||
truncation: bool = True
|
truncation: bool = True
|
||||||
@@ -121,6 +135,7 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
raise ValueError("Complementary data is None so no task can be extracted from it")
|
raise ValueError("Complementary data is None so no task can be extracted from it")
|
||||||
|
|
||||||
task = complementary_data[self.task_key]
|
task = complementary_data[self.task_key]
|
||||||
|
|
||||||
if task is None:
|
if task is None:
|
||||||
raise ValueError("Task extracted from Complementary data is None")
|
raise ValueError("Task extracted from Complementary data is None")
|
||||||
|
|
||||||
@@ -132,6 +147,60 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_high_level_task(self, transition: EnvTransition) -> list[str] | None:
    """
    Look up the high-level task (user prompt) stored in the transition's complementary data.

    Args:
        transition: The environment transition to read from.

    Returns:
        A list of strings ready for the tokenizer: a single string is wrapped in a
        one-element list, and a list made up only of strings is returned unchanged.
        Returns None when the complementary data is absent, when nothing is stored
        under `self.high_level_task_key`, or when the stored value is neither a
        string nor a list of strings.
    """
    extras = transition.get(TransitionKey.COMPLEMENTARY_DATA)
    # A missing complementary-data dict is treated the same as a missing key.
    prompt = None if extras is None else extras.get(self.high_level_task_key)

    if isinstance(prompt, str):
        return [prompt]
    if isinstance(prompt, list) and all(isinstance(item, str) for item in prompt):
        return prompt

    # None, or a type the tokenizer path does not handle.
    return None
|
||||||
|
|
||||||
|
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
    """
    Look up the subtask description stored in the transition's complementary data.

    Args:
        transition: The environment transition to read from.

    Returns:
        A list of strings ready for the tokenizer: a single string is wrapped in a
        one-element list, and a list made up only of strings is returned unchanged.
        Returns None when the complementary data is absent, when nothing is stored
        under `self.subtask_key`, or when the stored value is neither a string nor
        a list of strings.
    """
    extras = transition.get(TransitionKey.COMPLEMENTARY_DATA)
    # A missing complementary-data dict is treated the same as a missing key.
    entry = None if extras is None else extras.get(self.subtask_key)

    if isinstance(entry, str):
        return [entry]
    if isinstance(entry, list) and all(isinstance(item, str) for item in entry):
        return entry

    # None, or a type the tokenizer path does not handle.
    return None
|
||||||
|
|
||||||
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Tokenizes the task description and adds it to the observation dictionary.
|
Tokenizes the task description and adds it to the observation dictionary.
|
||||||
@@ -169,6 +238,40 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||||
|
|
||||||
|
# Also tokenize high-level task if available
|
||||||
|
high_level_task = self.get_high_level_task(self.transition)
|
||||||
|
if high_level_task is not None:
|
||||||
|
# Tokenize the high-level task
|
||||||
|
tokenized_high_level_prompt = self._tokenize_text(high_level_task)
|
||||||
|
|
||||||
|
# Move to the same device
|
||||||
|
if target_device is not None:
|
||||||
|
tokenized_high_level_prompt = {
|
||||||
|
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||||
|
for k, v in tokenized_high_level_prompt.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add high-level tokenized data to the observation
|
||||||
|
new_observation[OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS] = tokenized_high_level_prompt["input_ids"]
|
||||||
|
new_observation[OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK] = tokenized_high_level_prompt["attention_mask"].to(dtype=torch.bool)
|
||||||
|
|
||||||
|
# Also tokenize subtask if available
|
||||||
|
subtask = self.get_subtask(self.transition)
|
||||||
|
if subtask is not None:
|
||||||
|
# Tokenize the subtask
|
||||||
|
tokenized_subtask_prompt = self._tokenize_text(subtask)
|
||||||
|
|
||||||
|
# Move to the same device
|
||||||
|
if target_device is not None:
|
||||||
|
tokenized_subtask_prompt = {
|
||||||
|
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||||
|
for k, v in tokenized_subtask_prompt.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add subtask tokenized data to the observation
|
||||||
|
new_observation[OBS_LANGUAGE_SUBTASK_ONLY_TOKENS] = tokenized_subtask_prompt["input_ids"]
|
||||||
|
new_observation[OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK] = tokenized_subtask_prompt["attention_mask"].to(dtype=torch.bool)
|
||||||
|
|
||||||
return new_observation
|
return new_observation
|
||||||
|
|
||||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||||
@@ -229,6 +332,7 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
config = {
|
config = {
|
||||||
"max_length": self.max_length,
|
"max_length": self.max_length,
|
||||||
"task_key": self.task_key,
|
"task_key": self.task_key,
|
||||||
|
"high_level_task_key": self.high_level_task_key,
|
||||||
"padding_side": self.padding_side,
|
"padding_side": self.padding_side,
|
||||||
"padding": self.padding,
|
"padding": self.padding,
|
||||||
"truncation": self.truncation,
|
"truncation": self.truncation,
|
||||||
@@ -267,4 +371,25 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add features for high-level task tokens and attention mask if they don't already exist
|
||||||
|
if OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||||
|
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS] = PolicyFeature(
|
||||||
|
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||||
|
)
|
||||||
|
|
||||||
|
if OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||||
|
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK] = PolicyFeature(
|
||||||
|
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||||
|
)
|
||||||
|
|
||||||
|
if OBS_LANGUAGE_SUBTASK_ONLY_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||||
|
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ONLY_TOKENS] = PolicyFeature(
|
||||||
|
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||||
|
)
|
||||||
|
|
||||||
|
if OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||||
|
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK] = PolicyFeature(
|
||||||
|
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||||
|
)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|||||||
@@ -26,7 +26,12 @@ OBS_IMAGES = OBS_IMAGE + "s"
|
|||||||
OBS_LANGUAGE = OBS_STR + ".language"
|
OBS_LANGUAGE = OBS_STR + ".language"
|
||||||
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
||||||
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||||
|
OBS_LANGUAGE_HIGH_LEVEL_TASK = OBS_STR + ".user_prompt"
|
||||||
|
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS = OBS_LANGUAGE_HIGH_LEVEL_TASK + ".tokens"
|
||||||
|
OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK = OBS_LANGUAGE_HIGH_LEVEL_TASK + ".attention_mask"
|
||||||
|
OBS_LANGUAGE_SUBTASK_ONLY = OBS_STR + ".subtask"
|
||||||
|
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS = OBS_LANGUAGE_SUBTASK_ONLY + ".tokens"
|
||||||
|
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK_ONLY + ".attention_mask"
|
||||||
ACTION = "action"
|
ACTION = "action"
|
||||||
REWARD = "next.reward"
|
REWARD = "next.reward"
|
||||||
TRUNCATED = "next.truncated"
|
TRUNCATED = "next.truncated"
|
||||||
|
|||||||
@@ -266,7 +266,7 @@ def create_original_observation_with_openpi_preprocessing(batch):
|
|||||||
elif len(tasks) == 1:
|
elif len(tasks) == 1:
|
||||||
tasks = tasks * batch_size
|
tasks = tasks * batch_size
|
||||||
|
|
||||||
# Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep)
|
# Use pi05 state and input tokenizer logic (same as Pi05PrepareStateAndLanguageTokenizerProcessorStep)
|
||||||
state = batch["observation.state"]
|
state = batch["observation.state"]
|
||||||
state = deepcopy(state)
|
state = deepcopy(state)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user