mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
Merge branch 'feat/add-pi05' of github.com:huggingface/lerobot into feat/add-pi05
This commit is contained in:
@@ -3,7 +3,8 @@ from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/libero-10-annotate")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
@@ -13,7 +14,18 @@ dataloader = torch.utils.data.DataLoader(
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
cfg.dtype = "bfloat16"
|
||||
|
||||
pre_processor, post_processor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path="/fsx/jade_choghari/models/pi05-base",
|
||||
)
|
||||
batch = next(iter(dataloader))
|
||||
batch1 = pre_processor(batch)
|
||||
breakpoint()
|
||||
print(batch.keys())
|
||||
# print(batch['task_index_high_level'].shape)
|
||||
# print(batch['task_index_high_level'])
|
||||
|
||||
@@ -45,6 +45,7 @@ dataloader = torch.utils.data.DataLoader(
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
breakpoint()
|
||||
batch = pre_processor(batch)
|
||||
policy.train()
|
||||
# run inference
|
||||
|
||||
@@ -1717,6 +1717,7 @@ class PI05FullPolicy(PreTrainedPolicy):
|
||||
|
||||
# Action queue logic for n_action_steps > 1
|
||||
if len(self._action_queue) == 0:
|
||||
# TODO: jadechoghari, generate subtask tokens here - ideally every 1 second
|
||||
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||
# Transpose to get shape (n_action_steps, batch_size, action_dim)
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
|
||||
@@ -254,6 +254,24 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
if subtask is not None:
|
||||
tokenized_subtask = self._tokenize_text(subtask)
|
||||
|
||||
# Add EOS token at the end of each subtask sequence (before padding)
|
||||
eos_token_id = self.input_tokenizer.eos_token_id
|
||||
input_ids = tokenized_subtask["input_ids"]
|
||||
attention_mask = tokenized_subtask["attention_mask"]
|
||||
for i in range(input_ids.size(0)):
|
||||
# Find the length of actual tokens (sum of attention mask)
|
||||
seq_len = attention_mask[i].sum().item()
|
||||
|
||||
max_len = input_ids.size(1)
|
||||
if seq_len >= max_len:
|
||||
raise ValueError(
|
||||
f"No room to append EOS: seq_len={seq_len} equals max_length={max_len}. "
|
||||
"Increase max_length or tokenize with padding=False then pad after adding EOS."
|
||||
)
|
||||
# Add EOS token at the end
|
||||
input_ids[i, seq_len] = eos_token_id
|
||||
attention_mask[i, seq_len] = 1
|
||||
|
||||
# Move new tokenized tensors to the detected device
|
||||
if target_device is not None:
|
||||
tokenized_subtask = {
|
||||
@@ -638,4 +656,4 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
Returns:
|
||||
The updated dictionary of policy features.
|
||||
"""
|
||||
return features
|
||||
return features
|
||||
Reference in New Issue
Block a user