mirror of https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
align fast more
@@ -1419,17 +1419,25 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
 # For next-token prediction:
 # - Position (fast_start - 1) in input predicts fast_action_tokens[0]
 # - Position (fast_start) in input predicts fast_action_tokens[1], etc.
-T_lang = masks.shape[1]
-fast_start = total_T_images + T_lang
-
-# Extract logits for FAST token prediction
-# Input positions [fast_start-1 : fast_start-1+num_fast_embs] predict FAST tokens
-fast_hidden = prefix_out[:, fast_start-1:fast_start-1+num_fast_embs, :]  # (B, num_fast_embs, hidden_dim)
-fast_logits_for_pred = lm_head(fast_hidden)  # (B, num_fast_embs, gemma_vocab_size)
-
 # Targets are the FAST action tokens
 fast_targets = fast_action_tokens  # (B, num_fast_embs)
+T_lang = masks.shape[1]
+fast_start = total_T_images + T_lang
+
+# Extract logits for FAST token prediction
+# Input positions [fast_start-1 : fast_start-1+num_fast_embs] predict FAST tokens
+# fast_hidden = prefix_out[:, fast_start-1:fast_start-1+num_fast_embs, :]  # (B, num_fast_embs, hidden_dim)
+fast_hidden = prefix_out[:, -fast_targets.shape[1]:, :]
+fast_logits_for_pred = lm_head(fast_hidden)  # (B, num_fast_embs, gemma_vocab_size)
+
+# Shift for next-step prediction:
+# logits[:, i] predicts targets[:, i+1]
+fast_logits_for_pred = fast_logits_for_pred[:, :-1, :]  # Drop last logit
+fast_targets = fast_targets[:, 1:]  # Drop first target so targets align with shifted logits
+fast_action_masks = fast_action_masks[:, 1:]  # Shift masks to match targets
+
 # Compute cross-entropy loss
 loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
 fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1))
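Note: the new slice takes the last fast_targets.shape[1] prefix positions, then the pair of shifts aligns logits[:, i] with targets[:, i+1] before the masked cross-entropy below. A minimal self-contained sketch of the same pattern (all shapes and tensor contents here are illustrative, not the model's):

    import torch

    B, T, V = 2, 5, 16                      # batch, FAST tokens, vocab (illustrative)
    logits = torch.randn(B, T, V)           # stand-in for lm_head(fast_hidden)
    targets = torch.randint(0, V, (B, T))   # stand-in for fast_action_tokens
    mask = torch.ones(B, T, dtype=torch.bool)
    mask[1, 3:] = False                     # e.g. padded positions in sample 1

    # logits[:, i] predicts targets[:, i+1]: drop the last logit,
    # drop the first target, and realign the mask with the targets.
    logits_s, targets_s, mask_s = logits[:, :-1, :], targets[:, 1:], mask[:, 1:]

    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    per_token = loss_fct(logits_s.reshape(-1, V), targets_s.reshape(-1)).reshape(B, T - 1)
    loss = (per_token * mask_s.float()).sum() / mask_s.sum().clamp(min=1)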
@@ -1441,140 +1449,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
 # Apply mask and compute mean loss
 masked_fast_loss = fast_loss_per_token * fast_action_masks.float()
 fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1)

-# breakpoint()
-# from transformers import AutoTokenizer, AutoProcessor
-# _paligemma_tokenizer = AutoTokenizer.from_pretrained(
-#     "google/paligemma-3b-pt-224",
-#     trust_remote_code=True,
-#     add_eos_token=True,
-#     add_bos_token=False
-# )
-# # 257152
-# # # Decode predicted output tokens
-# # # fast_logits_for_pred.argmax(dim=-1)
-# def _paligemma_tokens_to_act_tokens(tokens: torch.Tensor) -> torch.Tensor:
-#     """
-#     Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).
-#     """
-#     return _paligemma_tokenizer.vocab_size - 1 - 128 - tokens
-# # # target = _paligemma_tokens_to_act_tokens(fast_targets)
-# decoded_tokens = _paligemma_tokenizer.batch_decode(fast_targets, skip_special_tokens=False)
-# decoded_tokens = [
-#     _paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
-#     for seq in fast_logits_for_pred.argmax(dim=-1)
-# ]
-# cleaned_tokens = []
-# for token_seq in decoded_tokens:
-#     if "|" in token_seq:
-#         token_seq = token_seq[:token_seq.index("|")]
-#     cleaned_tokens.append(token_seq)
-# raw_action_tokens = [
-#     torch.tensor(
-#         _paligemma_tokenizer.convert_tokens_to_ids(token_seq),
-#         dtype=torch.long,
-#         device=fast_targets.device,
-#     )
-#     for token_seq in cleaned_tokens
-# ]
-
-# action_tokens = [
-#     _paligemma_tokens_to_act_tokens(raw_action_token)
-#     for raw_action_token in raw_action_tokens
-# ]
-# breakpoint()
-# # Clean the decoded tokens by removing "Action:" prefix and extracting the relevant part
-# cleaned_tokens = [
-#     tokens_sequence.strip().split("|")[0].strip()
-#     for tokens_sequence in decoded_tokens
-# ]
-
-# # Re-encode the cleaned text to get raw action tokens
-# raw_action_tokens = [
-#     _paligemma_tokenizer.encode(sample_tokens, return_tensors="pt", padding=False).squeeze(0)
-#     for sample_tokens in cleaned_tokens
-# ]
-# # Convert PaliGemma tokens back to action tokens
-# action_tokens = [
-#     _paligemma_tokens_to_act_tokens(raw_action_token)
-#     for raw_action_token in raw_action_tokens
-# ]
-# # # Decode each sample's tokens to continuous actions
-# action_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
-# # breakpoint()
-# decoded_actions = action_tokenizer.decode(
-#     action_tokens,
-#     time_horizon=self.config.chunk_size,
-#     action_dim=6
-# )
-# breakpoint()
-# def decode_actions_with_fast(
-#     token_ids: list[int],
-#     time_horizon: int,
-#     action_dim: int,
-#     relaxed_decoding: bool = False
-# ) -> list:
-#     """
-#     Decodes action token IDs back to continuous action values using the FAST tokenizer.
-
-#     Args:
-#         token_ids: List of token IDs to decode.
-#         time_horizon: The number of timesteps for actions.
-#         action_dim: The dimensionality of each action.
-#         relaxed_decoding: Whether to use relaxed decoding (allows partial sequences).
-
-#     Returns:
-#         A list representing the decoded actions.
-#     """
-#     # Use the action tokenizer's decode method
-#     # The FAST tokenizer should have a decode method that converts tokens back to actions
-#     try:
-#         decoded_actions = action_tokenizer.decode(
-#             token_ids,
-#             time_horizon=time_horizon,
-#             action_dim=action_dim
-#         )
-#         return decoded_actions
-#     except Exception as e:
-#         if relaxed_decoding:
-#             # If relaxed decoding is enabled, try to decode as much as possible
-#             import logging
-#             logging.warning(f"Relaxed decoding: {e}. Returning partial decode.")
-#             try:
-#                 # Try to decode with whatever tokens we have
-#                 partial_decoded = action_tokenizer.decode(
-#                     token_ids[:len(token_ids)],
-#                     time_horizon=time_horizon,
-#                     action_dim=action_dim
-#                 )
-#                 return partial_decoded
-#             except:
-#                 # Return zeros if decoding completely fails
-#                 return [[0.0] * action_dim for _ in range(time_horizon)]
-#         else:
-#             raise e
-
-# valid = fast_logits_for_pred.argmax(dim=-1) <= (self._paligemma_tokenizer.vocab_size - 1 - 128)
-# fast_region = fast_logits_for_pred.argmax(dim=-1).masked_fill(~valid, 0)
-# fast_tokens = _paligemma_tokens_to_act_tokens(fast_region)
-# actions = decode_actions_with_fast(fast_tokens.tolist(), time_horizon=self.config.chunk_size, action_dim=7, relaxed_decoding=True)[0]
-# breakpoint()
-# decoded_actions = [
-#     torch.tensor(
-#         decode_actions_with_fast(
-#             tok[0].tolist(),
-#             time_horizon=self.config.chunk_size,
-#             action_dim=7,
-#             relaxed_decoding=True,
-#         ),
-#         device=tokens.device,
-#     ).squeeze(0)
-#     for tok in action_tokens
-# ]
-# breakpoint()
-# # Stack into a batch
-# result = torch.stack(decoded_actions, dim=0)
-# breakpoint()
 return {
     "fast_loss": fast_loss,
     "loss": fast_loss,
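Note: the deleted debug code depends on the FAST-to-PaliGemma token mapping being its own inverse: action tokens live at the top of the PaliGemma vocabulary, below 128 reserved ids. A round-trip sketch (the 257152 vocab size is taken from the deleted "# 257152" note above; everything else is illustrative):

    import torch

    PALIGEMMA_VOCAB_SIZE = 257152  # per the deleted "# 257152" comment

    def act_to_paligemma(tokens: torch.Tensor) -> torch.Tensor:
        # Offset map used in both directions; applying it twice is the identity.
        return PALIGEMMA_VOCAB_SIZE - 1 - 128 - tokens

    act = torch.tensor([0, 1, 2, 500])
    pg = act_to_paligemma(act)                     # tensor([257023, 257022, 257021, 256523])
    assert torch.equal(act_to_paligemma(pg), act)  # round-trips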
@@ -101,7 +101,7 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
     cleaned_high_level_task = cleaned_high_level_tasks[i]
     full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask: {cleaned_text}"
 else:
-    full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
+    full_prompt = f"Task: {cleaned_text}, State: {state_str};\n"  # trailing "Action: " removed (jade)

 low_level_prompts.append(full_prompt)
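Note: the low-level prompt now ends after the state instead of a trailing "Action: " cue. A sketch of the two layouts with illustrative task and state strings:

    cleaned_text = "pick up the black bowl"   # illustrative
    state_str = "0.12 -0.30 0.45"             # illustrative
    # before: f"Task: {cleaned_text}, State: {state_str};\nAction: "
    full_prompt = f"Task: {cleaned_text}, State: {state_str};\n"
    # -> "Task: pick up the black bowl, State: 0.12 -0.30 0.45;\n"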
@@ -9,21 +9,21 @@ accelerate launch --mixed_precision=bf16 --multi_gpu --num_processes=8 \
 $(which lerobot-train) \
     --dataset.repo_id=local \
     --dataset.root=/fsx/jade_choghari/data/libero \
-    --output_dir=/fsx/jade_choghari/outputs/libero_training_fast_4 \
+    --output_dir=/fsx/jade_choghari/outputs/libero_training_fast_5 \
     --job_name=libero_training_fast \
     --policy.repo_id=jade_choghari/pi05-fast-libero-8 \
     --policy.path=/fsx/jade_choghari/models/libero-pi-fast \
     --policy.dtype=bfloat16 \
-    --steps=60000 \
-    --save_freq=10000 \
+    --steps=120000 \
+    --save_freq=12000 \
     --batch_size=8 \
     --policy.compile_model=false \
     --policy.device=cuda \
     --policy.fast_only=true \
-    --policy.scheduler_warmup_steps=2000 \
-    --policy.scheduler_decay_steps=60000 \
+    --policy.scheduler_warmup_steps=4000 \
+    --policy.scheduler_decay_steps=120000 \
     --policy.scheduler_decay_lr=1e-5 \
     --policy.gradient_checkpointing=false \
     --wandb.enable=true \
     --wandb.disable_artifact=true \
     --wandb.project=pi05-libero-training \
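Note: the run length doubles (60k to 120k steps) and the scheduler warmup/decay steps are scaled with it, so the LR curve keeps its shape. A sketch of one common reading of these flags, linear warmup then cosine decay down to scheduler_decay_lr; the peak LR value and the exact decay curve are assumptions, not stated by this commit:

    import math

    def lr_at(step: int, peak_lr: float = 2.5e-5, decay_lr: float = 1e-5,
              warmup_steps: int = 4000, decay_steps: int = 120_000) -> float:
        # Linear warmup from 0 to peak_lr, then cosine decay to decay_lr.
        if step < warmup_steps:
            return peak_lr * (step + 1) / warmup_steps
        progress = min((step - warmup_steps) / (decay_steps - warmup_steps), 1.0)
        return decay_lr + 0.5 * (peak_lr - decay_lr) * (1.0 + math.cos(math.pi * progress))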
@@ -353,7 +353,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
 else:
     # If at max length, replace the last token with EOS
     input_ids[i, last_token_pos] = eos_token_id

 return {"input_ids": input_ids, "attention_mask": attention_mask}

 def get_config(self) -> dict[str, Any]:
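Note: the branch above handles the boundary case where a tokenized prompt already fills the context: there is no free slot to append EOS, so the final real token is overwritten. A standalone sketch (ids and lengths are illustrative):

    import torch

    eos_token_id = 1                                   # illustrative id
    max_len = 6
    input_ids = torch.tensor([[5, 8, 9, 4, 7, 3]])     # already at max length
    attention_mask = torch.ones_like(input_ids)

    i = 0
    last_token_pos = attention_mask[i].sum() - 1       # index of the final real token
    if attention_mask[i].sum() < max_len:
        pass  # room left: EOS would go in the first padding slot (not shown)
    else:
        input_ids[i, last_token_pos] = eos_token_id    # replace last token with EOS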
@@ -577,8 +576,13 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
 # Flatten to 1D if needed
 if tokens.dim() > 1:
     tokens = tokens.flatten()

-tokens = torch.cat([self._act_tokens_to_paligemma_tokens(tokens), torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)])
+tokens = torch.cat([
+    torch.tensor(self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False), device=action.device),
+    self._act_tokens_to_paligemma_tokens(tokens),
+    torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
+])

 # Truncate or pad to max_action_tokens
 if len(tokens) > self.max_action_tokens:
     import logging
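Note: the change prepends an "Action: " marker, so the serialized sequence reads [ "Action: " pieces | offset-mapped FAST ids | "|" separator ] before padding or truncation to max_action_tokens. A sketch using the same tokenizer calls (the sample FAST ids are illustrative):

    import torch
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")

    prefix = torch.tensor(tok.encode("Action: ", add_special_tokens=False))  # no BOS/EOS added
    sep = torch.tensor(tok.encode("|"))
    fast = tok.vocab_size - 1 - 128 - torch.tensor([0, 1, 2])  # offset-mapped FAST ids
    serialized = torch.cat([prefix, fast, sep])
    # then truncate or pad to max_action_tokens, as in the hunk above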
@@ -843,7 +847,6 @@ class ActionDetokenizerProcessorStep1(ActionProcessorStep):
     for raw_action_token in raw_action_tokens
 ]
 tokens = [t.flatten().tolist() for t in action_tokens]
-breakpoint()
 # Decode each sample's tokens to continuous actions
 decoded_actions = [
     torch.tensor(
@@ -857,7 +860,6 @@ class ActionDetokenizerProcessorStep1(ActionProcessorStep):
     ).squeeze(0)
     for tok in action_tokens
 ]
-breakpoint()
 # Stack into a batch
 result = torch.stack(decoded_actions, dim=0)
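Note: with the breakpoints removed, the detokenizer reduces to: flatten each sample's tokens, decode with the FAST processor, and stack into a batch. A condensed sketch following the decode(...) call visible in this file (the batching details and output shape are assumptions):

    import torch
    from transformers import AutoProcessor

    action_tokenizer = AutoProcessor.from_pretrained(
        "physical-intelligence/fast", trust_remote_code=True
    )

    def detokenize(action_tokens: list[torch.Tensor], chunk_size: int, action_dim: int) -> torch.Tensor:
        tokens = [t.flatten().tolist() for t in action_tokens]
        decoded_actions = [
            torch.tensor(
                action_tokenizer.decode([tok], time_horizon=chunk_size, action_dim=action_dim)
            ).squeeze(0)
            for tok in tokens
        ]
        return torch.stack(decoded_actions, dim=0)  # (B, chunk_size, action_dim)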
@@ -91,10 +91,10 @@ def update_policy(
 # Let accelerator handle mixed precision
 with accelerator.autocast():
     loss, output_dict = policy.forward(batch)
-    action = policy.predict_action_chunk(batch)
-    if postprocessor is not None:
-        action = postprocessor(action)
-    breakpoint()
+    # action = policy.predict_action_chunk(batch)
+    # if postprocessor is not None:
+    #     action = postprocessor(action)
+    # breakpoint()
 # TODO(rcadene): policy.unnormalize_outputs(out_dict)

 # Use accelerator's backward method
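Note: the training step keeps the standard Accelerate pattern: forward under accelerator.autocast() and backprop with accelerator.backward() rather than loss.backward(). A stripped-down sketch (everything outside the hunk, including the optimizer handling, is an assumption):

    from accelerate import Accelerator

    accelerator = Accelerator(mixed_precision="bf16")

    def update_policy(policy, batch, optimizer):
        # Let accelerator handle mixed precision, as in the hunk above.
        with accelerator.autocast():
            loss, output_dict = policy.forward(batch)

        # Use accelerator's backward method instead of loss.backward().
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        return loss.detach(), output_dict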