working step 2

This commit is contained in:
Jade Choghari
2025-12-10 09:53:29 +00:00
parent 025c2b2831
commit 8edbd5b55e
+140 -29
View File
@@ -195,7 +195,7 @@ class QwenPgen:
prompt: str,
) -> dict[str, str]:
"""
Call Qwen VLM to generate synthetic dialogue.
Call Qwen VLM to generate synthetic dialogue for a single request.
Args:
images: List of PIL Images or image paths
@@ -204,34 +204,91 @@ class QwenPgen:
Returns:
Dictionary with keys: scenario_type, response_type, user_prompt, robot_utterance
"""
# Build messages with images and text
content = []
for img in images:
if isinstance(img, str):
content.append({"type": "image", "image": img})
else:
# PIL Image - need to save temporarily or convert
content.append({"type": "image", "image": img})
# Use batch method with single item
results = self.call_qwen_batch([images], [prompt])
return results[0]
def call_qwen_batch(
self,
batch_images: list[list[Image.Image | str]],
batch_prompts: list[str],
) -> list[dict[str, str]]:
"""
Call Qwen VLM to generate synthetic dialogue for a batch of requests.
content.append({"type": "text", "text": prompt})
Args:
batch_images: List of image lists, one per request
batch_prompts: List of text prompts, one per request
Returns:
List of dictionaries, each with keys: scenario_type, response_type, user_prompt, robot_utterance
"""
if len(batch_images) != len(batch_prompts):
raise ValueError(f"Batch size mismatch: {len(batch_images)} image lists vs {len(batch_prompts)} prompts")
messages = [
{
"role": "user",
"content": content,
}
]
batch_size = len(batch_images)
if batch_size == 0:
return []
# Process inputs
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = self.process_vision_info(messages)
# Build messages for each item in batch
all_messages = []
for images, prompt in zip(batch_images, batch_prompts):
content = []
for img in images:
if isinstance(img, str):
content.append({"type": "image", "image": img})
else:
# PIL Image
content.append({"type": "image", "image": img})
content.append({"type": "text", "text": prompt})
messages = [
{
"role": "user",
"content": content,
}
]
all_messages.append(messages)
# Process all inputs
texts = []
all_image_inputs = []
all_video_inputs = []
for messages in all_messages:
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
texts.append(text)
image_inputs, video_inputs = self.process_vision_info(messages)
all_image_inputs.append(image_inputs)
all_video_inputs.append(video_inputs)
# Flatten image and video inputs for batch processing
# The processor expects a flat list of images across all batch items
flat_images = []
for img_list in all_image_inputs:
if img_list is not None:
if isinstance(img_list, list):
flat_images.extend(img_list)
else:
flat_images.append(img_list)
flat_videos = []
for vid_list in all_video_inputs:
if vid_list is not None:
if isinstance(vid_list, list):
flat_videos.extend(vid_list)
else:
flat_videos.append(vid_list)
# Process batch
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
text=texts,
images=flat_images if flat_images else None,
videos=flat_videos if flat_videos else None,
padding=True,
return_tensors="pt",
).to(self.device)
@@ -245,13 +302,29 @@ class QwenPgen:
temperature=self.temperature,
)
# Decode response
response = self.processor.batch_decode(
# Decode responses
responses = self.processor.batch_decode(
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
skip_special_tokens=True,
)[0].strip()
)
return self._parse_response(response)
# Parse all responses
results = []
for response in responses:
try:
parsed = self._parse_response(response.strip())
results.append(parsed)
except Exception as e:
self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]")
# Return empty/default result
results.append({
"scenario_type": "specific_object",
"response_type": "confirmation",
"user_prompt": "",
"robot_utterance": "",
})
return results
def _parse_response(self, response: str) -> dict[str, str]:
"""Parse JSON response from model."""
@@ -333,6 +406,39 @@ def annotate_sample(
return result
def annotate_samples_batch(
    pgen: QwenPgen,
    batch_images: list[list[Image.Image | str]],
    batch_task_descriptions: list[str],
    batch_skill_histories: list[list[str]],
    batch_skill_currents: list[str],
) -> list[dict[str, str]]:
    """
    Generate synthetic dialogue for a batch of samples.

    Args:
        pgen: Qwen model wrapper
        batch_images: List of image lists, one per sample
        batch_task_descriptions: List of task descriptions
        batch_skill_histories: List of skill history lists
        batch_skill_currents: List of current skills

    Returns:
        List of dictionaries with generated dialogue
    """
    # Assemble one prompt per sample from its task description,
    # skill history, and current skill.
    batch_prompts = [
        construct_prompt(task_desc, skill_hist, skill_curr)
        for task_desc, skill_hist, skill_curr in zip(
            batch_task_descriptions, batch_skill_histories, batch_skill_currents
        )
    ]

    # Hand the whole batch to the VLM in a single call.
    return pgen.call_qwen_batch(batch_images, batch_prompts)
def generate_synthetic_data(
dataset: LeRobotDataset,
pgen: QwenPgen,
@@ -733,6 +839,11 @@ def main():
output_dir=output_dir,
repo_id=repo_id,
)
# copy high-level task parquet to new output directory
import shutil
shutil.copy(dataset_root / "meta" / "tasks_high_level.parquet", output_dir / "meta" / "tasks_high_level.parquet")
shutil.copy(dataset_root / "meta" / "syn_annotations.jsonl", output_dir / "meta" / "syn_annotations.jsonl")
console.print(f"[bold green]✓ Successfully added task_index_high_level feature![/bold green]")
console.print(f" New dataset saved to: {new_dataset.root}")
@@ -745,7 +856,7 @@ def main():
else:
console.print("[cyan]Pushing to HuggingFace Hub...[/cyan]")
try:
new_dataset.push_to_hub(push_videos=False)
new_dataset.push_to_hub()
console.print(f"[green]✓ Pushed to {repo_id}[/green]")
except Exception as e:
console.print(f"[red]Push failed: {e}[/red]")