clean subtask

This commit is contained in:
Jade Choghari
2026-02-09 10:55:22 +01:00
parent 6aa0cc267f
commit 4503019d18
2 changed files with 125 additions and 51 deletions
@@ -81,6 +81,7 @@ from rich.progress import Progress, SpinnerColumn, TextColumn
from lerobot.datasets.dataset_tools import add_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import SKILL_SEGMENTATION_PROMPT_TEMPLATE
# Skill Annotation Data Structures
@@ -141,7 +142,11 @@ class BaseVLM(ABC):
@abstractmethod
def segment_skills(
self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
self,
video_path: Path,
episode_duration: float,
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[Skill]:
"""
Segment a video into atomic skills.
@@ -150,6 +155,7 @@ class BaseVLM(ABC):
video_path: Path to the video file
episode_duration: Total duration of the episode in seconds
coarse_goal: Optional high-level task description
subtask_labels: Optional list of allowed skill labels to use
Returns:
List of Skill objects representing atomic manipulation skills
@@ -158,7 +164,11 @@ class BaseVLM(ABC):
@abstractmethod
def segment_skills_batch(
self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
self,
video_paths: list[Path],
episode_durations: list[float],
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[list[Skill]]:
"""
Segment multiple videos into atomic skills in a single batch.
@@ -167,6 +177,7 @@ class BaseVLM(ABC):
video_paths: List of paths to video files
episode_durations: List of episode durations in seconds
coarse_goal: Optional high-level task description
subtask_labels: Optional list of allowed skill labels to use
Returns:
List of skill lists, one for each video
@@ -174,43 +185,36 @@ class BaseVLM(ABC):
pass
def create_skill_segmentation_prompt(coarse_goal: str | None = None) -> str:
"""Create the prompt for skill segmentation."""
def create_skill_segmentation_prompt(
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> str:
"""Create the prompt for skill segmentation.
Args:
coarse_goal: Optional high-level task description.
subtask_labels: Optional list of allowed skill/subtask labels. When provided,
the model is instructed to use only these labels (choosing the best match
for each segment).
Returns:
The formatted prompt string.
"""
goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else ""
return textwrap.dedent(f"""\
# Role
You are a Robotics Vision System specializing in temporal action segmentation for robot manipulation demonstrations.
# Task
{goal_context}Segment this robot demonstration video into short atomic manipulation skills. Each skill should:
- Last approximately 1-3 seconds
- Describe a clear, single action (e.g., "pick up object", "move arm left", "release gripper")
- Have precise start and end timestamps
# Requirements
1. **Atomic Actions**: Each skill should be a single, indivisible action
2. **Complete Coverage**: Skills must cover the entire video duration with no gaps
3. **Boundary Consistency**: The end of one skill equals the start of the next
4. **Natural Language**: Use clear, descriptive names for each skill
5. **Timestamps**: Use seconds (float) for all timestamps
# Output Format
After your analysis, output ONLY valid JSON with this exact structure:
```json
{{
"skills": [
{{"name": "skill description", "start": 0.0, "end": 1.5}},
{{"name": "another skill", "start": 1.5, "end": 3.2}}
]
}}
```
The first skill must start at 0.0 and the last skill must end at the video duration.
""")
if subtask_labels:
labels_str = ", ".join(f'"{label}"' for label in subtask_labels)
subtask_labels_section = (
f'6. **Allowed labels**: Use ONLY the following skill names '
f"(choose the best match for each segment): {labels_str}\n\n"
)
else:
subtask_labels_section = ""
return textwrap.dedent(
SKILL_SEGMENTATION_PROMPT_TEMPLATE.format(
goal_context=goal_context,
subtask_labels_section=subtask_labels_section,
)
)
# Qwen2-VL Implementation
@@ -238,10 +242,14 @@ class Qwen2VL(BaseVLM):
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
def segment_skills(
self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
self,
video_path: Path,
episode_duration: float,
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[Skill]:
"""Segment video into skills using Qwen2-VL."""
prompt = create_skill_segmentation_prompt(coarse_goal)
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels)
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
messages = [
@@ -279,10 +287,14 @@ class Qwen2VL(BaseVLM):
return self._parse_skills_response(response)
def segment_skills_batch(
self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
self,
video_paths: list[Path],
episode_durations: list[float],
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[list[Skill]]:
"""Segment multiple videos into skills using Qwen2-VL in a batch."""
prompt = create_skill_segmentation_prompt(coarse_goal)
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels)
# Create messages for each video
all_messages = []
@@ -394,10 +406,14 @@ class Qwen3VL(BaseVLM):
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
def segment_skills(
self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
self,
video_path: Path,
episode_duration: float,
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[Skill]:
"""Segment video into skills using Qwen3-VL."""
prompt = create_skill_segmentation_prompt(coarse_goal)
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels)
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
messages = [
@@ -435,10 +451,14 @@ class Qwen3VL(BaseVLM):
return self._parse_skills_response(response)
def segment_skills_batch(
self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
self,
video_paths: list[Path],
episode_durations: list[float],
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[list[Skill]]:
"""Segment multiple videos into skills using Qwen3-VL in a batch."""
prompt = create_skill_segmentation_prompt(coarse_goal)
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels)
# Create messages for each video
all_messages = []
@@ -679,6 +699,7 @@ class SkillAnnotator:
video_key: str,
episodes: list[int] | None = None,
skip_existing: bool = False,
subtask_labels: list[str] | None = None,
) -> dict[int, EpisodeSkills]:
"""
Annotate all episodes in a dataset with skill labels using batched processing.
@@ -688,6 +709,7 @@ class SkillAnnotator:
video_key: Key for video observations (e.g., "observation.images.base")
episodes: Specific episode indices to annotate (None = all)
skip_existing: Skip episodes that already have skill annotations
subtask_labels: Optional list of allowed skill labels (VLM will use only these)
Returns:
Dictionary mapping episode index to EpisodeSkills
@@ -732,7 +754,7 @@ class SkillAnnotator:
try:
batch_annotations = self._annotate_episodes_batch(
dataset, batch_episodes, video_key, coarse_goal
dataset, batch_episodes, video_key, coarse_goal, subtask_labels
)
for ep_idx in batch_episodes:
@@ -754,7 +776,9 @@ class SkillAnnotator:
# Fallback: process episodes one by one
for ep_idx in batch_episodes:
try:
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
skills = self._annotate_episode(
dataset, ep_idx, video_key, coarse_goal, subtask_labels
)
if skills:
annotations[ep_idx] = EpisodeSkills(
episode_index=ep_idx,
@@ -778,7 +802,9 @@ class SkillAnnotator:
for ep_idx, error_msg in list(failed_episodes.items()):
self.console.print(f"[cyan]Retry attempt for episode {ep_idx} (previous error: {error_msg})[/cyan]")
try:
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
skills = self._annotate_episode(
dataset, ep_idx, video_key, coarse_goal, subtask_labels
)
if skills:
annotations[ep_idx] = EpisodeSkills(
episode_index=ep_idx,
@@ -823,6 +849,7 @@ class SkillAnnotator:
episode_indices: list[int],
video_key: str,
coarse_goal: str,
subtask_labels: list[str] | None = None,
) -> dict[int, list[Skill]]:
"""Annotate multiple episodes with skill labels in a batch."""
# Extract all videos for this batch
@@ -863,7 +890,9 @@ class SkillAnnotator:
try:
# Run VLM skill segmentation in batch
all_skills = self.vlm.segment_skills_batch(extracted_paths, durations, coarse_goal)
all_skills = self.vlm.segment_skills_batch(
extracted_paths, durations, coarse_goal, subtask_labels
)
# Map results back to episode indices
results = {}
@@ -884,6 +913,7 @@ class SkillAnnotator:
episode_index: int,
video_key: str,
coarse_goal: str,
subtask_labels: list[str] | None = None,
) -> list[Skill]:
"""Annotate a single episode with skill labels."""
# Get video path and timestamps for this episode
@@ -905,7 +935,9 @@ class SkillAnnotator:
try:
# Run VLM skill segmentation
skills = self.vlm.segment_skills(extracted_path, duration, coarse_goal)
skills = self.vlm.segment_skills(
extracted_path, duration, coarse_goal, subtask_labels
)
return skills
finally:
# Clean up temporary file
@@ -1269,6 +1301,13 @@ def main():
action="store_true",
help="Skip episodes that already have annotations",
)
parser.add_argument(
"--subtask-labels",
type=str,
nargs="+",
default=None,
help="Optional list of allowed skill labels (VLM will use only these; space-separated)",
)
# Output options
parser.add_argument(
@@ -1325,6 +1364,7 @@ def main():
video_key=args.video_key,
episodes=args.episodes,
skip_existing=args.skip_existing,
subtask_labels=args.subtask_labels,
)
# Save annotations
+34
View File
@@ -92,3 +92,37 @@ LIBERO_KEY_JOINTS_POS = "robot_state/joints/pos"
LIBERO_KEY_JOINTS_VEL = "robot_state/joints/vel"
LIBERO_KEY_PIXELS_AGENTVIEW = "pixels/agentview_image"
LIBERO_KEY_PIXELS_EYE_IN_HAND = "pixels/robot0_eye_in_hand_image"
# Skill segmentation prompt template for VLM-based subtask annotation
# Placeholders: {goal_context}, {subtask_labels_section}
SKILL_SEGMENTATION_PROMPT_TEMPLATE = """# Role
You are a Robotics Vision System specializing in temporal action segmentation for robot manipulation demonstrations.
# Task
{goal_context}Segment this robot demonstration video into short atomic manipulation skills. Each skill should:
- Last approximately 1-3 seconds
- Describe a clear, single action (e.g., "pick up object", "move arm left", "release gripper")
- Have precise start and end timestamps
# Requirements
1. **Atomic Actions**: Each skill should be a single, indivisible action
2. **Complete Coverage**: Skills must cover the entire video duration with no gaps
3. **Boundary Consistency**: The end of one skill equals the start of the next
4. **Natural Language**: Use clear, descriptive names for each skill
5. **Timestamps**: Use seconds (float) for all timestamps
{subtask_labels_section}
# Output Format
After your analysis, output ONLY valid JSON with this exact structure:
```json
{{
"skills": [
{{"name": "skill description", "start": 0.0, "end": 1.5}},
{{"name": "another skill", "start": 1.5, "end": 3.2}}
]
}}
```
The first skill must start at 0.0 and the last skill must end at the video duration.
"""