rename and fix

This commit is contained in:
Pepijn
2025-12-13 22:27:08 +01:00
parent 522396a15a
commit 0f8aa7d03b
5 changed files with 189 additions and 245 deletions
+8 -39
View File
@@ -45,45 +45,14 @@ dataloader = torch.utils.data.DataLoader(
batch = next(iter(dataloader))
batch = pre_processor(batch)
# Test training forward pass
policy.train()
# run inference
# action = policy.select_action(batch)
loss, loss_dict = policy.forward(batch)
# import requests
# from PIL import Image
# from transformers import AutoProcessor
# model = policy.model.paligemma_with_expert.paligemma
# model = model.to(device="cuda", dtype=torch.bfloat16)
# model.eval()
# prompt = "Describe this image."
# url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
# image = Image.open(requests.get(url, stream=True).raw)
# processor = AutoProcessor.from_pretrained(
# "google/paligemma-3b-pt-224",
# )
# inputs = processor(image, prompt, return_tensors="pt").to(model.device)
# print("generating...")
# output = model.generate(
# **inputs,
# max_new_tokens=50,
# use_cache=True, # default dynamic cache
# )
# print(processor.decode(output[0], skip_special_tokens=True))
print(f"Training loss: {loss_dict}")
# # other model
# from transformers import PaliGemmaForConditionalGeneration
# model = PaliGemmaForConditionalGeneration.from_pretrained(
# "google/paligemma2-3b-pt-224",
# torch_dtype=torch.bfloat16,
# device_map="auto",
# )
# model.eval()
# print("generating...")
# output = model.generate(
# **inputs,
# max_new_tokens=100,
# use_cache=True, # default dynamic cache
# )
# print("Model 2 output:")
# print(processor.decode(output[0], skip_special_tokens=True))
# Test inference
policy.eval()
with torch.no_grad():
actions = policy.predict_action_chunk(batch)
print(f"Predicted actions shape: {actions.shape}")