diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 5e15fee67..164484cae 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -1419,17 +1419,25 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` # For next-token prediction: # - Position (fast_start - 1) in input predicts fast_action_tokens[0] # - Position (fast_start) in input predicts fast_action_tokens[1], etc. - T_lang = masks.shape[1] - fast_start = total_T_images + T_lang - - # Extract logits for FAST token prediction - # Input positions [fast_start-1 : fast_start-1+num_fast_embs] predict FAST tokens - fast_hidden = prefix_out[:, fast_start-1:fast_start-1+num_fast_embs, :] # (B, num_fast_embs, hidden_dim) - fast_logits_for_pred = lm_head(fast_hidden) # (B, num_fast_embs, gemma_vocab_size) # Targets are the FAST action tokens fast_targets = fast_action_tokens # (B, num_fast_embs) + T_lang = masks.shape[1] + fast_start = total_T_images + T_lang + + # Extract logits for FAST token prediction + # Input positions [fast_start-1 : fast_start-1+num_fast_embs] predict FAST tokens + # fast_hidden = prefix_out[:, fast_start-1:fast_start-1+num_fast_embs, :] # (B, num_fast_embs, hidden_dim) + fast_hidden = prefix_out[:, -fast_targets.shape[1]:, :] + fast_logits_for_pred = lm_head(fast_hidden) # (B, num_fast_embs, gemma_vocab_size) + + # Shift left for next-step prediction and shift target + # logits[:, i] predicts targets[:, i+1] + fast_logits_for_pred = fast_logits_for_pred[:, :-1, :] # Shift logits left + fast_targets = fast_targets[:, 1:] # Shift targets right + fast_action_masks = fast_action_masks[:, 1:] # Shift masks to match targets + # Compute cross-entropy loss loss_fct = torch.nn.CrossEntropyLoss(reduction='none') fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1)) @@ -1441,140 +1449,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` # Apply mask and compute mean loss masked_fast_loss = fast_loss_per_token * fast_action_masks.float() fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1) - - # breakpoint() - # from transformers import AutoTokenizer, AutoProcessor - # _paligemma_tokenizer = AutoTokenizer.from_pretrained( - # "google/paligemma-3b-pt-224", - # trust_remote_code=True, - # add_eos_token=True, - # add_bos_token=False - # ) - # # 257152 - # # # Decode predicted output tokens - # # # fast_logits_for_pred.argmax(dim=-1) - # def _paligemma_tokens_to_act_tokens(tokens: torch.Tensor) -> torch.Tensor: - # """ - # Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens). - # """ - # return _paligemma_tokenizer.vocab_size - 1 - 128 - tokens - # # # target = _paligemma_tokens_to_act_tokens(fast_targets) - # decoded_tokens = _paligemma_tokenizer.batch_decode(fast_targets, skip_special_tokens=False) - # decoded_tokens = [ - # _paligemma_tokenizer.convert_ids_to_tokens(seq.tolist()) - # for seq in fast_logits_for_pred.argmax(dim=-1) - # ] - # cleaned_tokens = [] - # for token_seq in decoded_tokens: - # if "|" in token_seq: - # token_seq = token_seq[:token_seq.index("|")] - # cleaned_tokens.append(token_seq) - # raw_action_tokens = [ - # torch.tensor( - # _paligemma_tokenizer.convert_tokens_to_ids(token_seq), - # dtype=torch.long, - # device=fast_targets.device, - # ) - # for token_seq in cleaned_tokens - # ] - # action_tokens = [ - # _paligemma_tokens_to_act_tokens(raw_action_token) - # for raw_action_token in raw_action_tokens - # ] - # breakpoint() - # # Clean the decoded tokens by removing "Action:" prefix and extracting the relevant part - # cleaned_tokens = [ - # tokens_sequence.strip().split("|")[0].strip() - # for tokens_sequence in decoded_tokens - # ] - - # # Re-encode the cleaned text to get raw action tokens - # raw_action_tokens = [ - # _paligemma_tokenizer.encode(sample_tokens, return_tensors="pt", padding=False).squeeze(0) - # for sample_tokens in cleaned_tokens - # ] - # # Convert PaliGemma tokens back to action tokens - # action_tokens = [ - # _paligemma_tokens_to_act_tokens(raw_action_token) - # for raw_action_token in raw_action_tokens - # ] - # # # Decode each sample's tokens to continuous actions - # action_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True) - # # breakpoint() - # decoded_actions = action_tokenizer.decode( - # action_tokens, - # time_horizon=self.config.chunk_size, - # action_dim=6 - # ) - # breakpoint() - # def decode_actions_with_fast( - # token_ids: list[int], - # time_horizon: int, - # action_dim: int, - # relaxed_decoding: bool = False - # ) -> list: - # """ - # Decodes action token IDs back to continuous action values using the FAST tokenizer. - - # Args: - # token_ids: List of token IDs to decode. - # time_horizon: The number of timesteps for actions. - # action_dim: The dimensionality of each action. - # relaxed_decoding: Whether to use relaxed decoding (allows partial sequences). - - # Returns: - # A list representing the decoded actions. - # """ - # # Use the action tokenizer's decode method - # # The FAST tokenizer should have a decode method that converts tokens back to actions - # try: - # decoded_actions = action_tokenizer.decode( - # token_ids, - # time_horizon=time_horizon, - # action_dim=action_dim - # ) - # return decoded_actions - # except Exception as e: - # if relaxed_decoding: - # # If relaxed decoding is enabled, try to decode as much as possible - # import logging - # logging.warning(f"Relaxed decoding: {e}. Returning partial decode.") - # try: - # # Try to decode with whatever tokens we have - # partial_decoded = action_tokenizer.decode( - # token_ids[:len(token_ids)], - # time_horizon=time_horizon, - # action_dim=action_dim - # ) - # return partial_decoded - # except: - # # Return zeros if decoding completely fails - # return [[0.0] * action_dim for _ in range(time_horizon)] - # else: - # raise e - - # valid = fast_logits_for_pred.argmax(dim=-1) <= (self._paligemma_tokenizer.vocab_size - 1 - 128) - # fast_region = fast_logits_for_pred.argmax(dim=-1).masked_fill(~valid, 0) - # fast_tokens = _paligemma_tokens_to_act_tokens(fast_region) - # actions = decode_actions_with_fast(fast_tokens.tolist(), time_horizon=self.config.chunk_size, action_dim=7, relaxed_decoding=True)[0] - # breakpoint() - # decoded_actions = [ - # torch.tensor( - # decode_actions_with_fast( - # tok[0].tolist(), - # time_horizon=self.config.chunk_size, - # action_dim=7, - # relaxed_decoding=True, - # ), - # device=tokens.device, - # ).squeeze(0) - # for tok in action_tokens - # ] - # breakpoint() - # # Stack into a batch - # result = torch.stack(decoded_actions, dim=0) - # breakpoint() return { "fast_loss": fast_loss, "loss": fast_loss, diff --git a/src/lerobot/policies/pi05/processor_pi05.py b/src/lerobot/policies/pi05/processor_pi05.py index 63feda122..be1008d6d 100644 --- a/src/lerobot/policies/pi05/processor_pi05.py +++ b/src/lerobot/policies/pi05/processor_pi05.py @@ -101,7 +101,7 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep): cleaned_high_level_task = cleaned_high_level_tasks[i] full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask: {cleaned_text}" else: - full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " + full_prompt = f"Task: {cleaned_text}, State: {state_str};\n" #remove Action by jade low_level_prompts.append(full_prompt) diff --git a/src/lerobot/policies/pi05/train_multi.sh b/src/lerobot/policies/pi05/train_multi.sh index fefa1b5cd..14644e0c3 100644 --- a/src/lerobot/policies/pi05/train_multi.sh +++ b/src/lerobot/policies/pi05/train_multi.sh @@ -9,21 +9,21 @@ accelerate launch --mixed_precision=bf16 --multi_gpu --num_processes=8 \ $(which lerobot-train) \ --dataset.repo_id=local \ --dataset.root=/fsx/jade_choghari/data/libero \ - --output_dir=/fsx/jade_choghari/outputs/libero_training_fast_4 \ + --output_dir=/fsx/jade_choghari/outputs/libero_training_fast_5 \ --job_name=libero_training_fast \ --policy.repo_id=jade_choghari/pi05-fast-libero-8 \ --policy.path=/fsx/jade_choghari/models/libero-pi-fast \ --policy.dtype=bfloat16 \ - --steps=60000 \ - --save_freq=10000 \ + --steps=120000 \ + --save_freq=12000 \ --batch_size=8 \ --policy.compile_model=false \ --policy.device=cuda \ --policy.fast_only=true \ - --policy.scheduler_warmup_steps=2000 \ - --policy.scheduler_decay_steps=60000 \ + --policy.scheduler_warmup_steps=4000 \ + --policy.scheduler_decay_steps=120000 \ --policy.scheduler_decay_lr=1e-5 \ --policy.gradient_checkpointing=false \ --wandb.enable=true \ --wandb.disable_artifact=true \ - --wandb.project=pi05-libero-training \ \ No newline at end of file + --wandb.project=pi05-libero-training \ diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 5971ae759..d1a1893f4 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -353,7 +353,6 @@ class TokenizerProcessorStep(ObservationProcessorStep): else: # If at max length, replace the last token with EOS input_ids[i, last_token_pos] = eos_token_id - return {"input_ids": input_ids, "attention_mask": attention_mask} def get_config(self) -> dict[str, Any]: @@ -577,8 +576,13 @@ class ActionTokenizerProcessorStep(ActionProcessorStep): # Flatten to 1D if needed if tokens.dim() > 1: tokens = tokens.flatten() + + tokens = torch.cat([ + torch.tensor(self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False), device=action.device), + self._act_tokens_to_paligemma_tokens(tokens), + torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device), + ]) - tokens = torch.cat([self._act_tokens_to_paligemma_tokens(tokens), torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)]) # Truncate or pad to max_action_tokens if len(tokens) > self.max_action_tokens: import logging @@ -843,7 +847,6 @@ class ActionDetokenizerProcessorStep1(ActionProcessorStep): for raw_action_token in raw_action_tokens ] tokens = [t.flatten().tolist() for t in action_tokens] - breakpoint() # Decode each sample's tokens to continuous actions decoded_actions = [ torch.tensor( @@ -857,7 +860,6 @@ class ActionDetokenizerProcessorStep1(ActionProcessorStep): ).squeeze(0) for tok in action_tokens ] - breakpoint() # Stack into a batch result = torch.stack(decoded_actions, dim=0) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 53aa0cfcd..2c65c5a65 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -91,10 +91,10 @@ def update_policy( # Let accelerator handle mixed precision with accelerator.autocast(): loss, output_dict = policy.forward(batch) - action = policy.predict_action_chunk(batch) - if postprocessor is not None: - action = postprocessor(action) - breakpoint() + # action = policy.predict_action_chunk(batch) + # if postprocessor is not None: + # action = postprocessor(action) + # breakpoint() # TODO(rcadene): policy.unnormalize_outputs(out_dict) # Use accelerator's backward method