mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 16:09:44 +00:00
add tests/fixes
This commit is contained in:
@@ -10,7 +10,7 @@
|
|||||||
set -e
|
set -e
|
||||||
|
|
||||||
# --------------- Configuration ---------------
|
# --------------- Configuration ---------------
|
||||||
REPO_ID="${REPO_ID:-lerobot-data-collection/round1_1}"
|
REPO_ID="${REPO_ID:-lerobot-data-collection/round1_4}"
|
||||||
# MODEL="${MODEL:-Qwen/Qwen3-VL-30B-A3B-Thinking}"
|
# MODEL="${MODEL:-Qwen/Qwen3-VL-30B-A3B-Thinking}"
|
||||||
MODEL="${MODEL:-Qwen/Qwen3.5-27B}"
|
MODEL="${MODEL:-Qwen/Qwen3.5-27B}"
|
||||||
# Or: MODEL="Qwen/Qwen2-VL-7B-Instruct"
|
# Or: MODEL="Qwen/Qwen2-VL-7B-Instruct"
|
||||||
|
|||||||
@@ -0,0 +1,2 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
# Data annotations for subtasks and VLM-based labeling.
|
||||||
@@ -5,13 +5,12 @@ from pathlib import Path
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from lerobot.datasets.dataset_tools import add_features
|
from lerobot.datasets.dataset_tools import add_features
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
create_subtasks_dataframe,
|
|
||||||
create_subtask_index_array,
|
create_subtask_index_array,
|
||||||
|
create_subtasks_dataframe,
|
||||||
save_subtasks,
|
save_subtasks,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -57,6 +56,7 @@ class EpisodeSkills:
|
|||||||
|
|
||||||
# Video Extraction Utilities
|
# Video Extraction Utilities
|
||||||
|
|
||||||
|
|
||||||
class VideoExtractor:
|
class VideoExtractor:
|
||||||
"""Utilities for extracting and processing video segments from LeRobot datasets."""
|
"""Utilities for extracting and processing video segments from LeRobot datasets."""
|
||||||
|
|
||||||
@@ -82,9 +82,8 @@ class VideoExtractor:
|
|||||||
Returns:
|
Returns:
|
||||||
Path to the extracted temporary video file
|
Path to the extracted temporary video file
|
||||||
"""
|
"""
|
||||||
tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
|
||||||
tmp_path = Path(tmp_file.name)
|
tmp_path = Path(tmp_file.name)
|
||||||
tmp_file.close()
|
|
||||||
|
|
||||||
duration = end_timestamp - start_timestamp
|
duration = end_timestamp - start_timestamp
|
||||||
|
|
||||||
@@ -115,8 +114,8 @@ class VideoExtractor:
|
|||||||
subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
|
subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
raise RuntimeError(f"FFmpeg failed: {e}") from e
|
raise RuntimeError(f"FFmpeg failed: {e}") from e
|
||||||
except FileNotFoundError:
|
except FileNotFoundError as e:
|
||||||
raise RuntimeError("FFmpeg not found. Please install ffmpeg.")
|
raise RuntimeError("FFmpeg not found. Please install ffmpeg.") from e
|
||||||
|
|
||||||
if not tmp_path.exists() or tmp_path.stat().st_size < 1024:
|
if not tmp_path.exists() or tmp_path.stat().st_size < 1024:
|
||||||
if tmp_path.exists():
|
if tmp_path.exists():
|
||||||
@@ -131,9 +130,8 @@ class VideoExtractor:
|
|||||||
Used so the VLM can read the timestamp from the image instead of relying on file metadata.
|
Used so the VLM can read the timestamp from the image instead of relying on file metadata.
|
||||||
Draws a black box with white text at top-right. Writes to a new temporary file and returns its path.
|
Draws a black box with white text at top-right. Writes to a new temporary file and returns its path.
|
||||||
"""
|
"""
|
||||||
out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as out_file:
|
||||||
out_path = Path(out_file.name)
|
out_path = Path(out_file.name)
|
||||||
out_file.close()
|
|
||||||
|
|
||||||
cap = cv2.VideoCapture(str(video_path))
|
cap = cv2.VideoCapture(str(video_path))
|
||||||
if not cap.isOpened():
|
if not cap.isOpened():
|
||||||
@@ -289,7 +287,9 @@ class SkillAnnotator:
|
|||||||
batch_end = min(batch_start + self.batch_size, len(episode_indices))
|
batch_end = min(batch_start + self.batch_size, len(episode_indices))
|
||||||
batch_episodes = episode_indices[batch_start:batch_end]
|
batch_episodes = episode_indices[batch_start:batch_end]
|
||||||
|
|
||||||
print(f"Processing batch {batch_start//self.batch_size + 1}/{(len(episode_indices) + self.batch_size - 1)//self.batch_size} (episodes {batch_episodes[0]} to {batch_episodes[-1]})...")
|
print(
|
||||||
|
f"Processing batch {batch_start // self.batch_size + 1}/{(len(episode_indices) + self.batch_size - 1) // self.batch_size} (episodes {batch_episodes[0]} to {batch_episodes[-1]})..."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
batch_annotations = self._annotate_episodes_batch(
|
batch_annotations = self._annotate_episodes_batch(
|
||||||
@@ -337,9 +337,7 @@ class SkillAnnotator:
|
|||||||
for ep_idx, error_msg in list(failed_episodes.items()):
|
for ep_idx, error_msg in list(failed_episodes.items()):
|
||||||
print(f"Retry attempt for episode {ep_idx} (previous error: {error_msg})")
|
print(f"Retry attempt for episode {ep_idx} (previous error: {error_msg})")
|
||||||
try:
|
try:
|
||||||
skills = self._annotate_episode(
|
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal, subtask_labels)
|
||||||
dataset, ep_idx, video_key, coarse_goal, subtask_labels
|
|
||||||
)
|
|
||||||
if skills:
|
if skills:
|
||||||
annotations[ep_idx] = EpisodeSkills(
|
annotations[ep_idx] = EpisodeSkills(
|
||||||
episode_index=ep_idx,
|
episode_index=ep_idx,
|
||||||
@@ -434,13 +432,11 @@ class SkillAnnotator:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Run VLM skill segmentation in batch
|
# Run VLM skill segmentation in batch
|
||||||
all_skills = self.vlm.segment_skills_batch(
|
all_skills = self.vlm.segment_skills_batch(paths_for_vlm, durations, coarse_goal, subtask_labels)
|
||||||
paths_for_vlm, durations, coarse_goal, subtask_labels
|
|
||||||
)
|
|
||||||
|
|
||||||
# Map results back to episode indices
|
# Map results back to episode indices
|
||||||
results = {}
|
results = {}
|
||||||
for ep_idx, skills in zip(valid_episode_indices, all_skills):
|
for ep_idx, skills in zip(valid_episode_indices, all_skills, strict=True):
|
||||||
results[ep_idx] = skills
|
results[ep_idx] = skills
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@@ -486,9 +482,7 @@ class SkillAnnotator:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Run VLM skill segmentation
|
# Run VLM skill segmentation
|
||||||
skills = self.vlm.segment_skills(
|
skills = self.vlm.segment_skills(video_for_vlm, duration, coarse_goal, subtask_labels)
|
||||||
video_for_vlm, duration, coarse_goal, subtask_labels
|
|
||||||
)
|
|
||||||
return skills
|
return skills
|
||||||
finally:
|
finally:
|
||||||
# Clean up temporary files (extracted and optionally timer-overlay)
|
# Clean up temporary files (extracted and optionally timer-overlay)
|
||||||
@@ -568,11 +562,13 @@ def save_skill_annotations(
|
|||||||
existing_skills_data = None
|
existing_skills_data = None
|
||||||
if skills_path.exists():
|
if skills_path.exists():
|
||||||
try:
|
try:
|
||||||
with open(skills_path, "r") as f:
|
with open(skills_path) as f:
|
||||||
existing_skills_data = json.load(f)
|
existing_skills_data = json.load(f)
|
||||||
if existing_skills_data and len(existing_skills_data.get("episodes", {})) > 0:
|
if existing_skills_data and len(existing_skills_data.get("episodes", {})) > 0:
|
||||||
print(f"Found existing skills.json with {len(existing_skills_data.get('episodes', {}))} episodes, merging...")
|
print(
|
||||||
except (json.JSONDecodeError, IOError):
|
f"Found existing skills.json with {len(existing_skills_data.get('episodes', {}))} episodes, merging..."
|
||||||
|
)
|
||||||
|
except (OSError, json.JSONDecodeError):
|
||||||
print("Warning: Could not load existing skills.json, will create new file")
|
print("Warning: Could not load existing skills.json, will create new file")
|
||||||
existing_skills_data = None
|
existing_skills_data = None
|
||||||
|
|
||||||
@@ -590,14 +586,18 @@ def save_skill_annotations(
|
|||||||
merged_skill_to_subtask.update(skill_to_subtask_idx)
|
merged_skill_to_subtask.update(skill_to_subtask_idx)
|
||||||
|
|
||||||
# Use existing coarse_description if available, otherwise use new one
|
# Use existing coarse_description if available, otherwise use new one
|
||||||
coarse_desc = existing_skills_data.get("coarse_description", annotations[next(iter(annotations))].description)
|
coarse_desc = existing_skills_data.get(
|
||||||
|
"coarse_description", annotations[next(iter(annotations))].description
|
||||||
|
)
|
||||||
|
|
||||||
skills_data = {
|
skills_data = {
|
||||||
"coarse_description": coarse_desc,
|
"coarse_description": coarse_desc,
|
||||||
"skill_to_subtask_index": merged_skill_to_subtask,
|
"skill_to_subtask_index": merged_skill_to_subtask,
|
||||||
"episodes": merged_episodes,
|
"episodes": merged_episodes,
|
||||||
}
|
}
|
||||||
print(f"Updated {len(new_episodes)} episode(s), total episodes in skills.json: {len(merged_episodes)}")
|
print(
|
||||||
|
f"Updated {len(new_episodes)} episode(s), total episodes in skills.json: {len(merged_episodes)}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# No existing data, create new
|
# No existing data, create new
|
||||||
skills_data = {
|
skills_data = {
|
||||||
@@ -615,10 +615,7 @@ def save_skill_annotations(
|
|||||||
print("Adding subtask_index feature to dataset...")
|
print("Adding subtask_index feature to dataset...")
|
||||||
|
|
||||||
# Determine output directory and repo_id
|
# Determine output directory and repo_id
|
||||||
if output_dir is None:
|
output_dir = dataset.root.parent / f"{dataset.root.name}" if output_dir is None else Path(output_dir)
|
||||||
output_dir = dataset.root.parent / f"{dataset.root.name}"
|
|
||||||
else:
|
|
||||||
output_dir = Path(output_dir)
|
|
||||||
|
|
||||||
if repo_id is None:
|
if repo_id is None:
|
||||||
repo_id = f"{dataset.repo_id}"
|
repo_id = f"{dataset.repo_id}"
|
||||||
@@ -640,14 +637,9 @@ def save_skill_annotations(
|
|||||||
|
|
||||||
# Copy subtasks.parquet to new output directory
|
# Copy subtasks.parquet to new output directory
|
||||||
import shutil
|
import shutil
|
||||||
shutil.copy(
|
|
||||||
dataset.root / "meta" / "subtasks.parquet",
|
shutil.copy(dataset.root / "meta" / "subtasks.parquet", output_dir / "meta" / "subtasks.parquet")
|
||||||
output_dir / "meta" / "subtasks.parquet"
|
shutil.copy(dataset.root / "meta" / "skills.json", output_dir / "meta" / "skills.json")
|
||||||
)
|
|
||||||
shutil.copy(
|
|
||||||
dataset.root / "meta" / "skills.json",
|
|
||||||
output_dir / "meta" / "skills.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(" Successfully added subtask_index feature!")
|
print(" Successfully added subtask_index feature!")
|
||||||
print(f" New dataset saved to: {new_dataset.root}")
|
print(f" New dataset saved to: {new_dataset.root}")
|
||||||
|
|||||||
@@ -7,13 +7,12 @@ from pathlib import Path
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.data_processing.data_annotations.subtask_annotations import Skill
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
SKILL_SEGMENTATION_PROMPT_TEMPLATE,
|
SKILL_SEGMENTATION_PROMPT_TEMPLATE,
|
||||||
format_subtask_labels_section,
|
format_subtask_labels_section,
|
||||||
)
|
)
|
||||||
|
|
||||||
from lerobot.data_processing.data_annotations.subtask_annotations import Skill
|
|
||||||
|
|
||||||
|
|
||||||
class BaseVLM(ABC):
|
class BaseVLM(ABC):
|
||||||
"""
|
"""
|
||||||
@@ -85,9 +84,7 @@ def create_skill_segmentation_prompt(
|
|||||||
if duration_seconds is None:
|
if duration_seconds is None:
|
||||||
raise ValueError("duration_seconds is required for skill segmentation prompt")
|
raise ValueError("duration_seconds is required for skill segmentation prompt")
|
||||||
goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else ""
|
goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else ""
|
||||||
subtask_labels_section = (
|
subtask_labels_section = format_subtask_labels_section(subtask_labels) if subtask_labels else ""
|
||||||
format_subtask_labels_section(subtask_labels) if subtask_labels else ""
|
|
||||||
)
|
|
||||||
video_duration_mm_ss = f"{int(duration_seconds // 60):02d}:{int(duration_seconds % 60):02d}"
|
video_duration_mm_ss = f"{int(duration_seconds // 60):02d}:{int(duration_seconds % 60):02d}"
|
||||||
return SKILL_SEGMENTATION_PROMPT_TEMPLATE.format(
|
return SKILL_SEGMENTATION_PROMPT_TEMPLATE.format(
|
||||||
goal_context=goal_context,
|
goal_context=goal_context,
|
||||||
@@ -99,6 +96,7 @@ def create_skill_segmentation_prompt(
|
|||||||
|
|
||||||
# Qwen2-VL Implementation
|
# Qwen2-VL Implementation
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VL(BaseVLM):
|
class Qwen2VL(BaseVLM):
|
||||||
"""Qwen2-VL model for skill segmentation."""
|
"""Qwen2-VL model for skill segmentation."""
|
||||||
|
|
||||||
@@ -157,10 +155,12 @@ class Qwen2VL(BaseVLM):
|
|||||||
).to(self.device)
|
).to(self.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
|
generated_ids = self.model.generate(
|
||||||
|
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
|
||||||
|
)
|
||||||
|
|
||||||
response = self.processor.batch_decode(
|
response = self.processor.batch_decode(
|
||||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
|
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
)[0].strip()
|
)[0].strip()
|
||||||
|
|
||||||
@@ -176,10 +176,8 @@ class Qwen2VL(BaseVLM):
|
|||||||
"""Segment multiple videos into skills using Qwen2-VL in a batch."""
|
"""Segment multiple videos into skills using Qwen2-VL in a batch."""
|
||||||
# Create messages for each video (prompt includes duration so each gets correct length)
|
# Create messages for each video (prompt includes duration so each gets correct length)
|
||||||
all_messages = []
|
all_messages = []
|
||||||
for video_path, duration in zip(video_paths, episode_durations):
|
for video_path, duration in zip(video_paths, episode_durations, strict=True):
|
||||||
prompt = create_skill_segmentation_prompt(
|
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
|
||||||
coarse_goal, subtask_labels, duration_seconds=duration
|
|
||||||
)
|
|
||||||
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
|
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||||
@@ -217,10 +215,12 @@ class Qwen2VL(BaseVLM):
|
|||||||
).to(self.device)
|
).to(self.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
|
generated_ids = self.model.generate(
|
||||||
|
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
|
||||||
|
)
|
||||||
|
|
||||||
responses = self.processor.batch_decode(
|
responses = self.processor.batch_decode(
|
||||||
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
|
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -321,10 +321,12 @@ class Qwen3VL(BaseVLM):
|
|||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
|
generated_ids = self.model.generate(
|
||||||
|
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
|
||||||
|
)
|
||||||
|
|
||||||
response = self.processor.batch_decode(
|
response = self.processor.batch_decode(
|
||||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
|
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
)[0].strip()
|
)[0].strip()
|
||||||
|
|
||||||
@@ -340,10 +342,8 @@ class Qwen3VL(BaseVLM):
|
|||||||
"""Segment multiple videos into skills using Qwen3-VL in a batch."""
|
"""Segment multiple videos into skills using Qwen3-VL in a batch."""
|
||||||
# Create messages for each video (prompt includes duration so each gets correct length)
|
# Create messages for each video (prompt includes duration so each gets correct length)
|
||||||
all_messages = []
|
all_messages = []
|
||||||
for video_path, duration in zip(video_paths, episode_durations):
|
for video_path, duration in zip(video_paths, episode_durations, strict=True):
|
||||||
prompt = create_skill_segmentation_prompt(
|
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
|
||||||
coarse_goal, subtask_labels, duration_seconds=duration
|
|
||||||
)
|
|
||||||
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
|
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||||
@@ -381,10 +381,12 @@ class Qwen3VL(BaseVLM):
|
|||||||
).to(self.device)
|
).to(self.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
|
generated_ids = self.model.generate(
|
||||||
|
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
|
||||||
|
)
|
||||||
|
|
||||||
responses = self.processor.batch_decode(
|
responses = self.processor.batch_decode(
|
||||||
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
|
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -427,7 +429,7 @@ class Qwen3VL(BaseVLM):
|
|||||||
# Qwen3.5-VL Implementation (Qwen3_5ForConditionalGeneration)
|
# Qwen3.5-VL Implementation (Qwen3_5ForConditionalGeneration)
|
||||||
|
|
||||||
|
|
||||||
class Qwen3_5VL(BaseVLM):
|
class Qwen35VL(BaseVLM):
|
||||||
"""Qwen3.5-VL model for skill segmentation (Qwen3_5ForConditionalGeneration)."""
|
"""Qwen3.5-VL model for skill segmentation (Qwen3_5ForConditionalGeneration)."""
|
||||||
|
|
||||||
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
|
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
|
||||||
@@ -486,7 +488,7 @@ class Qwen3_5VL(BaseVLM):
|
|||||||
generated_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
|
generated_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
|
||||||
|
|
||||||
response = self.processor.batch_decode(
|
response = self.processor.batch_decode(
|
||||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
|
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
clean_up_tokenization_spaces=False,
|
clean_up_tokenization_spaces=False,
|
||||||
)[0].strip()
|
)[0].strip()
|
||||||
@@ -502,10 +504,8 @@ class Qwen3_5VL(BaseVLM):
|
|||||||
) -> list[list[Skill]]:
|
) -> list[list[Skill]]:
|
||||||
"""Segment multiple videos into skills using Qwen3.5-VL in a batch."""
|
"""Segment multiple videos into skills using Qwen3.5-VL in a batch."""
|
||||||
all_messages = []
|
all_messages = []
|
||||||
for video_path, duration in zip(video_paths, episode_durations):
|
for video_path, duration in zip(video_paths, episode_durations, strict=True):
|
||||||
prompt = create_skill_segmentation_prompt(
|
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
|
||||||
coarse_goal, subtask_labels, duration_seconds=duration
|
|
||||||
)
|
|
||||||
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
|
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||||
@@ -527,7 +527,9 @@ class Qwen3_5VL(BaseVLM):
|
|||||||
all_video_inputs = []
|
all_video_inputs = []
|
||||||
|
|
||||||
for messages in all_messages:
|
for messages in all_messages:
|
||||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
|
text = self.processor.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
|
||||||
|
)
|
||||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||||
all_texts.append(text)
|
all_texts.append(text)
|
||||||
all_image_inputs.extend(image_inputs or [])
|
all_image_inputs.extend(image_inputs or [])
|
||||||
@@ -545,7 +547,7 @@ class Qwen3_5VL(BaseVLM):
|
|||||||
generated_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
|
generated_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
|
||||||
|
|
||||||
responses = self.processor.batch_decode(
|
responses = self.processor.batch_decode(
|
||||||
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
|
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
clean_up_tokenization_spaces=False,
|
clean_up_tokenization_spaces=False,
|
||||||
)
|
)
|
||||||
@@ -584,6 +586,7 @@ class Qwen3_5VL(BaseVLM):
|
|||||||
|
|
||||||
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
|
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
|
||||||
|
|
||||||
|
|
||||||
# VLM Registry - Add new VLMs here
|
# VLM Registry - Add new VLMs here
|
||||||
|
|
||||||
VLM_REGISTRY: dict[str, type[BaseVLM]] = {
|
VLM_REGISTRY: dict[str, type[BaseVLM]] = {
|
||||||
@@ -594,8 +597,8 @@ VLM_REGISTRY: dict[str, type[BaseVLM]] = {
|
|||||||
# Qwen3-VL variants (MoE)
|
# Qwen3-VL variants (MoE)
|
||||||
"Qwen/Qwen3-VL-30B-A3B-Instruct": Qwen3VL,
|
"Qwen/Qwen3-VL-30B-A3B-Instruct": Qwen3VL,
|
||||||
# Qwen3.5-VL (Qwen3_5ForConditionalGeneration)
|
# Qwen3.5-VL (Qwen3_5ForConditionalGeneration)
|
||||||
"Qwen/Qwen3.5-27B": Qwen3_5VL,
|
"Qwen/Qwen3.5-27B": Qwen35VL,
|
||||||
"Qwen/Qwen3-VL-8B-Instruct": Qwen3_5VL,
|
"Qwen/Qwen3-VL-8B-Instruct": Qwen35VL,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -621,7 +624,7 @@ def get_vlm(model_name: str, device: str = "cuda", torch_dtype: torch.dtype = to
|
|||||||
# Check for partial matches (e.g., "qwen2" in model name)
|
# Check for partial matches (e.g., "qwen2" in model name)
|
||||||
model_lower = model_name.lower()
|
model_lower = model_name.lower()
|
||||||
if "qwen3.5" in model_lower:
|
if "qwen3.5" in model_lower:
|
||||||
return Qwen3_5VL(model_name, device, torch_dtype)
|
return Qwen35VL(model_name, device, torch_dtype)
|
||||||
if "qwen3" in model_lower:
|
if "qwen3" in model_lower:
|
||||||
return Qwen3VL(model_name, device, torch_dtype)
|
return Qwen3VL(model_name, device, torch_dtype)
|
||||||
elif "qwen2" in model_lower or "qwen-vl" in model_lower:
|
elif "qwen2" in model_lower or "qwen-vl" in model_lower:
|
||||||
|
|||||||
@@ -1220,8 +1220,9 @@ def find_float_index(target, float_list, threshold=1e-6):
|
|||||||
return i
|
return i
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
|
|
||||||
def create_subtasks_dataframe(
|
def create_subtasks_dataframe(
|
||||||
annotations: "dict[int, EpisodeSkills]",
|
annotations: dict[int, EpisodeSkills],
|
||||||
) -> tuple[pd.DataFrame, dict[str, int]]:
|
) -> tuple[pd.DataFrame, dict[str, int]]:
|
||||||
"""
|
"""
|
||||||
Create a subtasks DataFrame from skill annotations.
|
Create a subtasks DataFrame from skill annotations.
|
||||||
@@ -1237,23 +1238,24 @@ def create_subtasks_dataframe(
|
|||||||
for episode_skills in annotations.values():
|
for episode_skills in annotations.values():
|
||||||
for skill in episode_skills.skills:
|
for skill in episode_skills.skills:
|
||||||
all_skill_names.add(skill.name)
|
all_skill_names.add(skill.name)
|
||||||
|
|
||||||
print(f"Found {len(all_skill_names)} unique subtasks")
|
|
||||||
|
|
||||||
# Build subtasks DataFrame
|
# Build subtasks DataFrame
|
||||||
subtask_data = []
|
subtask_data = []
|
||||||
for i, skill_name in enumerate(sorted(all_skill_names)):
|
for i, skill_name in enumerate(sorted(all_skill_names)):
|
||||||
subtask_data.append({
|
subtask_data.append(
|
||||||
|
{
|
||||||
"subtask": skill_name,
|
"subtask": skill_name,
|
||||||
"subtask_index": i,
|
"subtask_index": i,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not subtask_data:
|
||||||
|
subtasks_df = pd.DataFrame(columns=["subtask", "subtask_index"]).set_index("subtask")
|
||||||
|
else:
|
||||||
subtasks_df = pd.DataFrame(subtask_data).set_index("subtask")
|
subtasks_df = pd.DataFrame(subtask_data).set_index("subtask")
|
||||||
|
|
||||||
# Build skill name to subtask_index mapping
|
# Build skill name to subtask_index mapping
|
||||||
skill_to_subtask_idx = {
|
skill_to_subtask_idx = {
|
||||||
skill_name: int(subtasks_df.loc[skill_name, "subtask_index"])
|
skill_name: int(subtasks_df.loc[skill_name, "subtask_index"]) for skill_name in all_skill_names
|
||||||
for skill_name in all_skill_names
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return subtasks_df, skill_to_subtask_idx
|
return subtasks_df, skill_to_subtask_idx
|
||||||
@@ -1268,12 +1270,11 @@ def save_subtasks(
|
|||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
subtasks_df.to_parquet(output_path, engine="pyarrow", compression="snappy")
|
subtasks_df.to_parquet(output_path, engine="pyarrow", compression="snappy")
|
||||||
print(f" Saved subtasks to {output_path}")
|
|
||||||
|
|
||||||
|
|
||||||
def create_subtask_index_array(
|
def create_subtask_index_array(
|
||||||
dataset: "LeRobotDataset",
|
dataset: LeRobotDataset,
|
||||||
annotations: "dict[int, EpisodeSkills]",
|
annotations: dict[int, EpisodeSkills],
|
||||||
skill_to_subtask_idx: dict[str, int],
|
skill_to_subtask_idx: dict[str, int],
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
@@ -1292,8 +1293,6 @@ def create_subtask_index_array(
|
|||||||
full_dataset_length = len(dataset)
|
full_dataset_length = len(dataset)
|
||||||
subtask_indices = np.full(full_dataset_length, -1, dtype=np.int64)
|
subtask_indices = np.full(full_dataset_length, -1, dtype=np.int64)
|
||||||
|
|
||||||
print(f"Creating subtask_index array for {full_dataset_length} frames...")
|
|
||||||
|
|
||||||
# Assign subtask_index for each annotated episode
|
# Assign subtask_index for each annotated episode
|
||||||
fps = float(dataset.meta.fps)
|
fps = float(dataset.meta.fps)
|
||||||
for ep_idx, episode_skills in annotations.items():
|
for ep_idx, episode_skills in annotations.items():
|
||||||
@@ -1324,7 +1323,6 @@ def create_subtask_index_array(
|
|||||||
subtask_idx = skill_to_subtask_idx[skill.name]
|
subtask_idx = skill_to_subtask_idx[skill.name]
|
||||||
subtask_indices[frame_idx] = subtask_idx
|
subtask_indices[frame_idx] = subtask_idx
|
||||||
|
|
||||||
print(" Created subtask_index array")
|
|
||||||
return subtask_indices
|
return subtask_indices
|
||||||
|
|
||||||
|
|
||||||
@@ -1391,7 +1389,7 @@ class Backtrackable[T]:
|
|||||||
self._history = history
|
self._history = history
|
||||||
self._lookahead = lookahead
|
self._lookahead = lookahead
|
||||||
|
|
||||||
def __iter__(self) -> "Backtrackable[T]":
|
def __iter__(self) -> Backtrackable[T]:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __next__(self) -> T:
|
def __next__(self) -> T:
|
||||||
|
|||||||
@@ -70,8 +70,8 @@ class SubtaskAnnotateConfig:
|
|||||||
output_dir: str | None = None
|
output_dir: str | None = None
|
||||||
output_repo_id: str | None = None
|
output_repo_id: str | None = None
|
||||||
push_to_hub: bool = False
|
push_to_hub: bool = False
|
||||||
# Closed vocabulary: model must choose only from these labels
|
# Closed vocabulary: comma-separated labels (e.g. "label1,label2,label3")
|
||||||
subtask_labels: list[str] | None = None
|
subtask_labels: str | None = None
|
||||||
# Disable timer overlay on video (by default a timer is drawn for the VLM)
|
# Disable timer overlay on video (by default a timer is drawn for the VLM)
|
||||||
no_timer_overlay: bool = False
|
no_timer_overlay: bool = False
|
||||||
|
|
||||||
@@ -87,6 +87,11 @@ def subtask_annotate(cfg: SubtaskAnnotateConfig):
|
|||||||
if (cfg.data_dir is None) == (cfg.repo_id is None):
|
if (cfg.data_dir is None) == (cfg.repo_id is None):
|
||||||
raise ValueError("Provide exactly one of --data_dir or --repo_id")
|
raise ValueError("Provide exactly one of --data_dir or --repo_id")
|
||||||
|
|
||||||
|
# Parse comma-separated subtask labels into a list (or None)
|
||||||
|
subtask_labels_list: list[str] | None = None
|
||||||
|
if cfg.subtask_labels and cfg.subtask_labels.strip():
|
||||||
|
subtask_labels_list = [s.strip() for s in cfg.subtask_labels.split(",") if s.strip()]
|
||||||
|
|
||||||
dtype_map = {
|
dtype_map = {
|
||||||
"bfloat16": torch.bfloat16,
|
"bfloat16": torch.bfloat16,
|
||||||
"float16": torch.float16,
|
"float16": torch.float16,
|
||||||
@@ -96,9 +101,7 @@ def subtask_annotate(cfg: SubtaskAnnotateConfig):
|
|||||||
|
|
||||||
print("Loading dataset...")
|
print("Loading dataset...")
|
||||||
if cfg.data_dir:
|
if cfg.data_dir:
|
||||||
dataset = LeRobotDataset(
|
dataset = LeRobotDataset(repo_id="local/dataset", root=cfg.data_dir, download_videos=False)
|
||||||
repo_id="local/dataset", root=cfg.data_dir, download_videos=False
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
dataset = LeRobotDataset(repo_id=cfg.repo_id, download_videos=True)
|
dataset = LeRobotDataset(repo_id=cfg.repo_id, download_videos=True)
|
||||||
|
|
||||||
@@ -106,9 +109,7 @@ def subtask_annotate(cfg: SubtaskAnnotateConfig):
|
|||||||
|
|
||||||
if cfg.video_key not in dataset.meta.video_keys:
|
if cfg.video_key not in dataset.meta.video_keys:
|
||||||
available = ", ".join(dataset.meta.video_keys)
|
available = ", ".join(dataset.meta.video_keys)
|
||||||
raise ValueError(
|
raise ValueError(f"Video key '{cfg.video_key}' not found. Available: {available}")
|
||||||
f"Video key '{cfg.video_key}' not found. Available: {available}"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Initializing VLM: {cfg.model}...")
|
print(f"Initializing VLM: {cfg.model}...")
|
||||||
vlm = get_vlm(cfg.model, cfg.device, torch_dtype)
|
vlm = get_vlm(cfg.model, cfg.device, torch_dtype)
|
||||||
@@ -125,14 +126,12 @@ def subtask_annotate(cfg: SubtaskAnnotateConfig):
|
|||||||
video_key=cfg.video_key,
|
video_key=cfg.video_key,
|
||||||
episodes=cfg.episodes,
|
episodes=cfg.episodes,
|
||||||
skip_existing=cfg.skip_existing,
|
skip_existing=cfg.skip_existing,
|
||||||
subtask_labels=cfg.subtask_labels,
|
subtask_labels=subtask_labels_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
output_dir = Path(cfg.output_dir) if cfg.output_dir else None
|
output_dir = Path(cfg.output_dir) if cfg.output_dir else None
|
||||||
output_repo_id = cfg.output_repo_id
|
output_repo_id = cfg.output_repo_id
|
||||||
new_dataset = save_skill_annotations(
|
new_dataset = save_skill_annotations(dataset, annotations, output_dir, output_repo_id)
|
||||||
dataset, annotations, output_dir, output_repo_id
|
|
||||||
)
|
|
||||||
|
|
||||||
total_skills = sum(len(ann.skills) for ann in annotations.values())
|
total_skills = sum(len(ann.skills) for ann in annotations.values())
|
||||||
print("\nAnnotation complete!")
|
print("\nAnnotation complete!")
|
||||||
|
|||||||
@@ -23,11 +23,18 @@ These tests verify that:
|
|||||||
- Subtask handling gracefully handles missing data
|
- Subtask handling gracefully handles missing data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.data_processing.data_annotations.subtask_annotations import EpisodeSkills, Skill
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.datasets.utils import (
|
||||||
|
create_subtask_index_array,
|
||||||
|
create_subtasks_dataframe,
|
||||||
|
save_subtasks,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSubtaskDataset:
|
class TestSubtaskDataset:
|
||||||
@@ -188,3 +195,164 @@ class TestSubtaskEdgeCases:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
subtask_map[idx] = subtask
|
subtask_map[idx] = subtask
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateSubtasksDataframe:
    """Checks on the table and name->index mapping built by create_subtasks_dataframe."""

    def test_empty_annotations(self):
        """No annotations: zero rows, a lone ``subtask_index`` column, empty mapping."""
        table, name_to_idx = create_subtasks_dataframe({})

        assert name_to_idx == {}
        assert list(table.columns) == ["subtask_index"]
        assert len(table) == 0

    def test_single_episode_single_skill(self):
        """One episode holding one skill yields exactly one row mapped to index 0."""
        only_episode = EpisodeSkills(
            episode_index=0,
            description="Pick",
            skills=[Skill("pick", 0.0, 1.0)],
        )

        table, name_to_idx = create_subtasks_dataframe({0: only_episode})

        assert len(table) == 1
        assert list(table.index) == ["pick"]
        assert table.loc["pick", "subtask_index"] == 0
        assert name_to_idx == {"pick": 0}

    def test_multiple_episodes_overlapping_skills(self):
        """A skill name shared by two episodes appears once; rows come out sorted."""
        first_episode = EpisodeSkills(
            episode_index=0,
            description="Ep0",
            skills=[
                Skill("place", 0.0, 0.5),
                Skill("pick", 0.5, 1.0),
            ],
        )
        second_episode = EpisodeSkills(
            episode_index=1,
            description="Ep1",
            skills=[Skill("pick", 0.0, 1.0)],
        )

        table, name_to_idx = create_subtasks_dataframe({0: first_episode, 1: second_episode})

        # "pick" sorts before "place", so it takes index 0.
        assert list(table.index) == ["pick", "place"]
        assert int(table.loc["pick", "subtask_index"]) == 0
        assert int(table.loc["place", "subtask_index"]) == 1
        assert name_to_idx["pick"] == 0
        assert name_to_idx["place"] == 1

    def test_skills_sorted_alphabetically(self):
        """Rows follow alphabetical skill order regardless of annotation order."""
        shuffled_skills = [
            Skill("z_final", 0.0, 0.33),
            Skill("a_first", 0.33, 0.66),
            Skill("m_mid", 0.66, 1.0),
        ]
        episode = EpisodeSkills(episode_index=0, description="Ep", skills=shuffled_skills)

        table, _ = create_subtasks_dataframe({0: episode})

        assert list(table.index) == ["a_first", "m_mid", "z_final"]
        assert list(table["subtask_index"]) == [0, 1, 2]
|
||||||
|
|
||||||
|
|
||||||
|
class TestSaveSubtasks:
    """Checks that save_subtasks persists the subtask table under ``meta/``."""

    def test_save_subtasks_creates_file(self, tmp_path):
        """The parquet file appears at meta/subtasks.parquet even though meta/ did not exist."""
        expected = pd.DataFrame(
            {"subtask": ["pick", "place"], "subtask_index": [0, 1]}
        ).set_index("subtask")

        save_subtasks(expected, tmp_path)

        parquet_path = tmp_path / "meta" / "subtasks.parquet"
        assert parquet_path.exists()

        loaded = pd.read_parquet(parquet_path)
        # Compare with the index flattened into a column so both frames line up positionally.
        pd.testing.assert_frame_equal(loaded.reset_index(), expected.reset_index())

    def test_save_subtasks_content_matches(self, tmp_path):
        """Index labels and subtask_index values survive the write/read round trip."""
        expected = pd.DataFrame(
            {"subtask": ["a", "b"], "subtask_index": [0, 1]}
        ).set_index("subtask")

        save_subtasks(expected, tmp_path)
        parquet_path = tmp_path / "meta" / "subtasks.parquet"
        loaded = pd.read_parquet(parquet_path)

        assert loaded.index.tolist() == expected.index.tolist()
        assert list(loaded["subtask_index"]) == list(expected["subtask_index"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateSubtaskIndexArray:
    """Checks on the per-frame subtask index array built by create_subtask_index_array."""

    @pytest.fixture
    def dataset_with_episodes(self, tmp_path, empty_lerobot_dataset_factory):
        """Build a two-episode dataset (10 frames per episode) and reload it from disk."""
        state_feature = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
        writer = empty_lerobot_dataset_factory(root=tmp_path / "subtask_idx", features=state_feature)

        # One episode per task, ten random-state frames each.
        for task_name in ("Task A", "Task B"):
            for _ in range(10):
                writer.add_frame({"state": torch.randn(2), "task": task_name})
            writer.save_episode()

        writer.finalize()
        return LeRobotDataset(writer.repo_id, root=writer.root)

    def test_unannotated_all_minus_one(self, dataset_with_episodes):
        """An empty annotation dict leaves every frame at -1 in an int64 array."""
        name_to_idx = {"pick": 0, "place": 1}

        arr = create_subtask_index_array(dataset_with_episodes, {}, name_to_idx)

        assert len(arr) == len(dataset_with_episodes)
        assert arr.dtype == np.int64
        assert (arr == -1).all()

    def test_annotated_episode_assigns_by_timestamp(self, dataset_with_episodes):
        """Frames of an annotated episode take the index of the skill covering their timestamp."""
        # Per the original test's note the dataset runs at DEFAULT_FPS=30, so episode 0 has
        # timestamps 0, 1/30, ..., 9/30 (~0.3s). With "pick" on [0, 0.2) and "place" on
        # [0.2, 0.5), 0.2s is 6 frames: frames 0-5 are pick and frames 6-9 are place.
        pick_then_place = EpisodeSkills(
            episode_index=0,
            description="Pick and place",
            skills=[
                Skill("pick", 0.0, 0.2),  # frames 0-5 at 30 fps
                Skill("place", 0.2, 0.5),  # frames 6-9 at 30 fps
            ],
        )
        name_to_idx = {"pick": 0, "place": 1}

        arr = create_subtask_index_array(dataset_with_episodes, {0: pick_then_place}, name_to_idx)

        assert len(arr) == 20
        # Episode 0 occupies global frames 0-9.
        for i, value in enumerate(arr[:6]):
            assert value == 0, f"frame {i} should be pick"
        for i, value in enumerate(arr[6:10], start=6):
            assert value == 1, f"frame {i} should be place"
        # Episode 1 carries no annotation, so its frames keep the -1 marker.
        assert all(arr[i] == -1 for i in range(10, 20))

    def test_partial_annotations_leave_others_minus_one(self, dataset_with_episodes):
        """Annotating only episode 1 leaves episode 0 at -1 and labels episode 1's frames."""
        place_only = EpisodeSkills(
            episode_index=1,
            description="Place only",
            skills=[Skill("place", 0.0, 1.0)],
        )
        name_to_idx = {"place": 0}

        arr = create_subtask_index_array(dataset_with_episodes, {1: place_only}, name_to_idx)

        assert all(arr[i] == -1 for i in range(10))
        assert all(arr[i] == 0 for i in range(10, 20))
|
||||||
|
|||||||
Reference in New Issue
Block a user