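align fast more: shift FAST logits/targets for next-token loss, move "Action: " prefix from the prompt into the action tokens, remove debug code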

This commit is contained in:
Jade Choghari
2025-12-26 17:24:39 +00:00
parent f0923e5c86
commit 4b40153c32
5 changed files with 32 additions and 155 deletions
+15 -140
View File
@@ -1419,17 +1419,25 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# For next-token prediction:
# - Position (fast_start - 1) in input predicts fast_action_tokens[0]
# - Position (fast_start) in input predicts fast_action_tokens[1], etc.
T_lang = masks.shape[1]
fast_start = total_T_images + T_lang
# Extract logits for FAST token prediction
# Input positions [fast_start-1 : fast_start-1+num_fast_embs] predict FAST tokens
fast_hidden = prefix_out[:, fast_start-1:fast_start-1+num_fast_embs, :] # (B, num_fast_embs, hidden_dim)
fast_logits_for_pred = lm_head(fast_hidden) # (B, num_fast_embs, gemma_vocab_size)
# Targets are the FAST action tokens
fast_targets = fast_action_tokens # (B, num_fast_embs)
T_lang = masks.shape[1]
fast_start = total_T_images + T_lang
# Extract logits for FAST token prediction
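# Previously, input positions [fast_start-1 : fast_start-1+num_fast_embs] were sliced (kept below for reference);
# the trailing fast_targets.shape[1] positions of prefix_out are now used instead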
# fast_hidden = prefix_out[:, fast_start-1:fast_start-1+num_fast_embs, :] # (B, num_fast_embs, hidden_dim)
fast_hidden = prefix_out[:, -fast_targets.shape[1]:, :]
fast_logits_for_pred = lm_head(fast_hidden) # (B, num_fast_embs, gemma_vocab_size)
# Next-token alignment: drop the last logit and the first target so that
# logits[:, i] is scored against targets[:, i+1]
fast_logits_for_pred = fast_logits_for_pred[:, :-1, :] # Drop the last position (it has no next token to predict)
fast_targets = fast_targets[:, 1:] # Drop the first target (no position in this slice predicts it)
fast_action_masks = fast_action_masks[:, 1:] # Drop the first mask entry to stay aligned with targets
# Compute cross-entropy loss
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1))
@@ -1441,140 +1449,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Apply mask and compute mean loss
masked_fast_loss = fast_loss_per_token * fast_action_masks.float()
fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1)
# breakpoint()
# from transformers import AutoTokenizer, AutoProcessor
# _paligemma_tokenizer = AutoTokenizer.from_pretrained(
# "google/paligemma-3b-pt-224",
# trust_remote_code=True,
# add_eos_token=True,
# add_bos_token=False
# )
# # 257152
# # # Decode predicted output tokens
# # # fast_logits_for_pred.argmax(dim=-1)
# def _paligemma_tokens_to_act_tokens(tokens: torch.Tensor) -> torch.Tensor:
# """
# Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).
# """
# return _paligemma_tokenizer.vocab_size - 1 - 128 - tokens
# # # target = _paligemma_tokens_to_act_tokens(fast_targets)
# decoded_tokens = _paligemma_tokenizer.batch_decode(fast_targets, skip_special_tokens=False)
# decoded_tokens = [
# _paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
# for seq in fast_logits_for_pred.argmax(dim=-1)
# ]
# cleaned_tokens = []
# for token_seq in decoded_tokens:
# if "|" in token_seq:
# token_seq = token_seq[:token_seq.index("|")]
# cleaned_tokens.append(token_seq)
# raw_action_tokens = [
# torch.tensor(
# _paligemma_tokenizer.convert_tokens_to_ids(token_seq),
# dtype=torch.long,
# device=fast_targets.device,
# )
# for token_seq in cleaned_tokens
# ]
# action_tokens = [
# _paligemma_tokens_to_act_tokens(raw_action_token)
# for raw_action_token in raw_action_tokens
# ]
# breakpoint()
# # Clean the decoded tokens by removing "Action:" prefix and extracting the relevant part
# cleaned_tokens = [
# tokens_sequence.strip().split("|")[0].strip()
# for tokens_sequence in decoded_tokens
# ]
# # Re-encode the cleaned text to get raw action tokens
# raw_action_tokens = [
# _paligemma_tokenizer.encode(sample_tokens, return_tensors="pt", padding=False).squeeze(0)
# for sample_tokens in cleaned_tokens
# ]
# # Convert PaliGemma tokens back to action tokens
# action_tokens = [
# _paligemma_tokens_to_act_tokens(raw_action_token)
# for raw_action_token in raw_action_tokens
# ]
# # # Decode each sample's tokens to continuous actions
# action_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
# # breakpoint()
# decoded_actions = action_tokenizer.decode(
# action_tokens,
# time_horizon=self.config.chunk_size,
# action_dim=6
# )
# breakpoint()
# def decode_actions_with_fast(
# token_ids: list[int],
# time_horizon: int,
# action_dim: int,
# relaxed_decoding: bool = False
# ) -> list:
# """
# Decodes action token IDs back to continuous action values using the FAST tokenizer.
# Args:
# token_ids: List of token IDs to decode.
# time_horizon: The number of timesteps for actions.
# action_dim: The dimensionality of each action.
# relaxed_decoding: Whether to use relaxed decoding (allows partial sequences).
# Returns:
# A list representing the decoded actions.
# """
# # Use the action tokenizer's decode method
# # The FAST tokenizer should have a decode method that converts tokens back to actions
# try:
# decoded_actions = action_tokenizer.decode(
# token_ids,
# time_horizon=time_horizon,
# action_dim=action_dim
# )
# return decoded_actions
# except Exception as e:
# if relaxed_decoding:
# # If relaxed decoding is enabled, try to decode as much as possible
# import logging
# logging.warning(f"Relaxed decoding: {e}. Returning partial decode.")
# try:
# # Try to decode with whatever tokens we have
# partial_decoded = action_tokenizer.decode(
# token_ids[:len(token_ids)],
# time_horizon=time_horizon,
# action_dim=action_dim
# )
# return partial_decoded
# except:
# # Return zeros if decoding completely fails
# return [[0.0] * action_dim for _ in range(time_horizon)]
# else:
# raise e
# valid = fast_logits_for_pred.argmax(dim=-1) <= (self._paligemma_tokenizer.vocab_size - 1 - 128)
# fast_region = fast_logits_for_pred.argmax(dim=-1).masked_fill(~valid, 0)
# fast_tokens = _paligemma_tokens_to_act_tokens(fast_region)
# actions = decode_actions_with_fast(fast_tokens.tolist(), time_horizon=self.config.chunk_size, action_dim=7, relaxed_decoding=True)[0]
# breakpoint()
# decoded_actions = [
# torch.tensor(
# decode_actions_with_fast(
# tok[0].tolist(),
# time_horizon=self.config.chunk_size,
# action_dim=7,
# relaxed_decoding=True,
# ),
# device=tokens.device,
# ).squeeze(0)
# for tok in action_tokens
# ]
# breakpoint()
# # Stack into a batch
# result = torch.stack(decoded_actions, dim=0)
# breakpoint()
return {
"fast_loss": fast_loss,
"loss": fast_loss,
+1 -1
View File
@@ -101,7 +101,7 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
cleaned_high_level_task = cleaned_high_level_tasks[i]
full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask: {cleaned_text}"
else:
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompt = f"Task: {cleaned_text}, State: {state_str};\n" # "Action: " suffix removed from the prompt (jade); ActionTokenizerProcessorStep now prepends it to the action tokens
low_level_prompts.append(full_prompt)
+6 -6
View File
@@ -9,21 +9,21 @@ accelerate launch --mixed_precision=bf16 --multi_gpu --num_processes=8 \
$(which lerobot-train) \
--dataset.repo_id=local \
--dataset.root=/fsx/jade_choghari/data/libero \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_4 \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_5 \
--job_name=libero_training_fast \
--policy.repo_id=jade_choghari/pi05-fast-libero-8 \
--policy.path=/fsx/jade_choghari/models/libero-pi-fast \
--policy.dtype=bfloat16 \
--steps=60000 \
--save_freq=10000 \
--steps=120000 \
--save_freq=12000 \
--batch_size=8 \
--policy.compile_model=false \
--policy.device=cuda \
--policy.fast_only=true \
--policy.scheduler_warmup_steps=2000 \
--policy.scheduler_decay_steps=60000 \
--policy.scheduler_warmup_steps=4000 \
--policy.scheduler_decay_steps=120000 \
--policy.scheduler_decay_lr=1e-5 \
--policy.gradient_checkpointing=false \
--wandb.enable=true \
--wandb.disable_artifact=true \
--wandb.project=pi05-libero-training \
--wandb.project=pi05-libero-training \
+6 -4
View File
@@ -353,7 +353,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
else:
# If at max length, replace the last token with EOS
input_ids[i, last_token_pos] = eos_token_id
return {"input_ids": input_ids, "attention_mask": attention_mask}
def get_config(self) -> dict[str, Any]:
@@ -577,8 +576,13 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
# Flatten to 1D if needed
if tokens.dim() > 1:
tokens = tokens.flatten()
tokens = torch.cat([
torch.tensor(self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False), device=action.device),
self._act_tokens_to_paligemma_tokens(tokens),
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
])
tokens = torch.cat([self._act_tokens_to_paligemma_tokens(tokens), torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)])
# Truncate or pad to max_action_tokens
if len(tokens) > self.max_action_tokens:
import logging
@@ -843,7 +847,6 @@ class ActionDetokenizerProcessorStep1(ActionProcessorStep):
for raw_action_token in raw_action_tokens
]
tokens = [t.flatten().tolist() for t in action_tokens]
breakpoint()
# Decode each sample's tokens to continuous actions
decoded_actions = [
torch.tensor(
@@ -857,7 +860,6 @@ class ActionDetokenizerProcessorStep1(ActionProcessorStep):
).squeeze(0)
for tok in action_tokens
]
breakpoint()
# Stack into a batch
result = torch.stack(decoded_actions, dim=0)
+4 -4
View File
@@ -91,10 +91,10 @@ def update_policy(
# Let accelerator handle mixed precision
with accelerator.autocast():
loss, output_dict = policy.forward(batch)
action = policy.predict_action_chunk(batch)
if postprocessor is not None:
action = postprocessor(action)
breakpoint()
# action = policy.predict_action_chunk(batch)
# if postprocessor is not None:
# action = postprocessor(action)
# breakpoint()
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
# Use accelerator's backward method