From 8edbd5b55e5a7b39ccc151e839e6655fd868875b Mon Sep 17 00:00:00 2001
From: Jade Choghari
Date: Wed, 10 Dec 2025 09:53:29 +0000
Subject: [PATCH] Add batched Qwen VLM annotation calls (step 2)

---
 examples/dataset/annotate_pgen.py | 169 +++++++++++++++++++++++++-----
 1 file changed, 140 insertions(+), 29 deletions(-)

diff --git a/examples/dataset/annotate_pgen.py b/examples/dataset/annotate_pgen.py
index b5d7884ff..4b69283e1 100644
--- a/examples/dataset/annotate_pgen.py
+++ b/examples/dataset/annotate_pgen.py
@@ -195,7 +195,7 @@ class QwenPgen:
         prompt: str,
     ) -> dict[str, str]:
         """
-        Call Qwen VLM to generate synthetic dialogue.
+        Call Qwen VLM to generate synthetic dialogue for a single request.
 
         Args:
             images: List of PIL Images or image paths
@@ -204,34 +204,91 @@ class QwenPgen:
         Returns:
             Dictionary with keys: scenario_type, response_type, user_prompt, robot_utterance
         """
-        # Build messages with images and text
-        content = []
-        for img in images:
-            if isinstance(img, str):
-                content.append({"type": "image", "image": img})
-            else:
-                # PIL Image - need to save temporarily or convert
-                content.append({"type": "image", "image": img})
+        # Delegate to the batch method with a single item
+        results = self.call_qwen_batch([images], [prompt])
+        return results[0]
+
+    def call_qwen_batch(
+        self,
+        batch_images: list[list[Image.Image | str]],
+        batch_prompts: list[str],
+    ) -> list[dict[str, str]]:
+        """
+        Call Qwen VLM to generate synthetic dialogue for a batch of requests.
 
-        content.append({"type": "text", "text": prompt})
+        Args:
+            batch_images: List of image lists, one per request
+            batch_prompts: List of text prompts, one per request
+
+        Returns:
+            List of dictionaries, each with keys: scenario_type, response_type, user_prompt, robot_utterance
+        """
+        if len(batch_images) != len(batch_prompts):
+            raise ValueError(f"Batch size mismatch: {len(batch_images)} image lists vs {len(batch_prompts)} prompts")
 
-        messages = [
-            {
-                "role": "user",
-                "content": content,
-            }
-        ]
+        batch_size = len(batch_images)
+        if batch_size == 0:
+            return []
 
-        # Process inputs
-        text = self.processor.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
-        image_inputs, video_inputs = self.process_vision_info(messages)
+        # Build messages for each item in the batch
+        all_messages = []
+        for images, prompt in zip(batch_images, batch_prompts):
+            content = []
+            for img in images:
+                if isinstance(img, str):
+                    content.append({"type": "image", "image": img})
+                else:
+                    # PIL Image
+                    content.append({"type": "image", "image": img})
+
+            content.append({"type": "text", "text": prompt})
+
+            messages = [
+                {
+                    "role": "user",
+                    "content": content,
+                }
+            ]
+            all_messages.append(messages)
 
+        # Process all inputs
+        texts = []
+        all_image_inputs = []
+        all_video_inputs = []
+
+        for messages in all_messages:
+            text = self.processor.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+            texts.append(text)
+
+            image_inputs, video_inputs = self.process_vision_info(messages)
+            all_image_inputs.append(image_inputs)
+            all_video_inputs.append(video_inputs)
+
+        # Flatten image and video inputs for batch processing
+        # The processor expects a flat list of images across all batch items
+        flat_images = []
+        for img_list in all_image_inputs:
+            if img_list is not None:
+                if isinstance(img_list, list):
+                    flat_images.extend(img_list)
+                else:
+                    flat_images.append(img_list)
+
+        flat_videos = []
+        for vid_list in all_video_inputs:
+            if vid_list is not None:
+                if isinstance(vid_list, list):
+                    flat_videos.extend(vid_list)
+                else:
+                    flat_videos.append(vid_list)
+
+        # Process batch
         inputs = self.processor(
-            text=[text],
-            images=image_inputs,
-            videos=video_inputs,
+            text=texts,
+            images=flat_images if flat_images else None,
+            videos=flat_videos if flat_videos else None,
             padding=True,
             return_tensors="pt",
         ).to(self.device)
@@ -245,13 +302,29 @@ class QwenPgen:
             temperature=self.temperature,
         )
 
-        # Decode response
-        response = self.processor.batch_decode(
+        # Decode responses
+        responses = self.processor.batch_decode(
             [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
             skip_special_tokens=True,
-        )[0].strip()
+        )
 
-        return self._parse_response(response)
+        # Parse all responses
+        results = []
+        for response in responses:
+            try:
+                parsed = self._parse_response(response.strip())
+                results.append(parsed)
+            except Exception as e:
+                self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]")
+                # Fall back to an empty/default result
+                results.append({
+                    "scenario_type": "specific_object",
+                    "response_type": "confirmation",
+                    "user_prompt": "",
+                    "robot_utterance": "",
+                })
+
+        return results
 
     def _parse_response(self, response: str) -> dict[str, str]:
         """Parse JSON response from model."""
@@ -333,6 +406,39 @@ def annotate_sample(
     return result
 
 
+def annotate_samples_batch(
+    pgen: QwenPgen,
+    batch_images: list[list[Image.Image | str]],
+    batch_task_descriptions: list[str],
+    batch_skill_histories: list[list[str]],
+    batch_skill_currents: list[str],
+) -> list[dict[str, str]]:
+    """
+    Generate synthetic dialogue for a batch of samples.
+
+    Args:
+        pgen: Qwen model wrapper
+        batch_images: List of image lists, one per sample
+        batch_task_descriptions: List of task descriptions
+        batch_skill_histories: List of skill history lists
+        batch_skill_currents: List of current skills
+
+    Returns:
+        List of dictionaries with generated dialogue
+    """
+    # Construct prompts for the entire batch
+    batch_prompts = []
+    for task_desc, skill_hist, skill_curr in zip(
+        batch_task_descriptions, batch_skill_histories, batch_skill_currents
+    ):
+        prompt = construct_prompt(task_desc, skill_hist, skill_curr)
+        batch_prompts.append(prompt)
+
+    # Process the entire batch in one call
+    results = pgen.call_qwen_batch(batch_images, batch_prompts)
+    return results
+
+
 def generate_synthetic_data(
     dataset: LeRobotDataset,
     pgen: QwenPgen,
@@ -733,6 +839,11 @@ def main():
             output_dir=output_dir,
             repo_id=repo_id,
         )
+
+        # Copy the high-level task parquet and synthetic annotations to the new output directory
+        import shutil
+        shutil.copy(dataset_root / "meta" / "tasks_high_level.parquet", output_dir / "meta" / "tasks_high_level.parquet")
+        shutil.copy(dataset_root / "meta" / "syn_annotations.jsonl", output_dir / "meta" / "syn_annotations.jsonl")
 
         console.print(f"[bold green]✓ Successfully added task_index_high_level feature![/bold green]")
         console.print(f"  New dataset saved to: {new_dataset.root}")
@@ -745,7 +856,7 @@
     else:
         console.print("[cyan]Pushing to HuggingFace Hub...[/cyan]")
        try:
-            new_dataset.push_to_hub(push_videos=False)
+            new_dataset.push_to_hub()
             console.print(f"[green]✓ Pushed to {repo_id}[/green]")
         except Exception as e:
             console.print(f"[red]Push failed: {e}[/red]")
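
--
Reviewer note (kept below the signature delimiter, so it is not part of what
`git am` applies): a minimal sketch of how the new batch entry points could be
driven. Only `annotate_samples_batch` and `call_qwen_batch` come from the hunks
above; the `QwenPgen(...)` constructor argument, the frame paths, and the
task/skill strings are placeholder assumptions, not values confirmed by this diff.

    from PIL import Image

    # Assumption: QwenPgen accepts a model name; the real constructor may differ.
    pgen = QwenPgen(model_name="Qwen/Qwen2-VL-7B-Instruct")

    # Two samples, each with a single camera frame (paths are placeholders).
    batch_images = [
        [Image.open("frame_000.png")],
        [Image.open("frame_001.png")],
    ]
    batch_task_descriptions = ["make a cup of tea", "make a cup of tea"]
    batch_skill_histories = [[], ["pick up the kettle"]]
    batch_skill_currents = ["pick up the kettle", "pour water into the cup"]

    # One prompt is constructed per sample, then the whole batch goes through
    # a single processor/generate call inside call_qwen_batch.
    results = annotate_samples_batch(
        pgen,
        batch_images,
        batch_task_descriptions,
        batch_skill_histories,
        batch_skill_currents,
    )
    for r in results:
        # Each dict carries: scenario_type, response_type, user_prompt, robot_utterance.
        print(r["scenario_type"], "->", r["robot_utterance"])

One design point in the decode path: generate() returns prompt plus continuation
for each row, and padding=True gives every row of inputs.input_ids the same length,
so slicing each output at len(inp) strips the (padded) prompt uniformly before
decoding. For the continuations themselves to be well-formed in a batch, the
tokenizer generally needs left padding, as is usual for batched generation with
decoder-only models; it is worth verifying the processor here is configured that way.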