mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
clean subtask
This commit is contained in:
@@ -81,6 +81,7 @@ from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
|
||||
from lerobot.datasets.dataset_tools import add_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import SKILL_SEGMENTATION_PROMPT_TEMPLATE
|
||||
|
||||
|
||||
# Skill Annotation Data Structures
|
||||
@@ -141,7 +142,11 @@ class BaseVLM(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def segment_skills(
|
||||
self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
|
||||
self,
|
||||
video_path: Path,
|
||||
episode_duration: float,
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[Skill]:
|
||||
"""
|
||||
Segment a video into atomic skills.
|
||||
@@ -150,6 +155,7 @@ class BaseVLM(ABC):
|
||||
video_path: Path to the video file
|
||||
episode_duration: Total duration of the episode in seconds
|
||||
coarse_goal: Optional high-level task description
|
||||
subtask_labels: Optional list of allowed skill labels to use
|
||||
|
||||
Returns:
|
||||
List of Skill objects representing atomic manipulation skills
|
||||
@@ -158,7 +164,11 @@ class BaseVLM(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def segment_skills_batch(
|
||||
self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
|
||||
self,
|
||||
video_paths: list[Path],
|
||||
episode_durations: list[float],
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[list[Skill]]:
|
||||
"""
|
||||
Segment multiple videos into atomic skills in a single batch.
|
||||
@@ -167,6 +177,7 @@ class BaseVLM(ABC):
|
||||
video_paths: List of paths to video files
|
||||
episode_durations: List of episode durations in seconds
|
||||
coarse_goal: Optional high-level task description
|
||||
subtask_labels: Optional list of allowed skill labels to use
|
||||
|
||||
Returns:
|
||||
List of skill lists, one for each video
|
||||
@@ -174,43 +185,36 @@ class BaseVLM(ABC):
|
||||
pass
|
||||
|
||||
|
||||
def create_skill_segmentation_prompt(coarse_goal: str | None = None) -> str:
|
||||
"""Create the prompt for skill segmentation."""
|
||||
def create_skill_segmentation_prompt(
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Create the prompt for skill segmentation.
|
||||
|
||||
Args:
|
||||
coarse_goal: Optional high-level task description.
|
||||
subtask_labels: Optional list of allowed skill/subtask labels. When provided,
|
||||
the model is instructed to use only these labels (choosing the best match
|
||||
for each segment).
|
||||
|
||||
Returns:
|
||||
The formatted prompt string.
|
||||
"""
|
||||
goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else ""
|
||||
|
||||
return textwrap.dedent(f"""\
|
||||
# Role
|
||||
You are a Robotics Vision System specializing in temporal action segmentation for robot manipulation demonstrations.
|
||||
|
||||
# Task
|
||||
{goal_context}Segment this robot demonstration video into short atomic manipulation skills. Each skill should:
|
||||
- Last approximately 1-3 seconds
|
||||
- Describe a clear, single action (e.g., "pick up object", "move arm left", "release gripper")
|
||||
- Have precise start and end timestamps
|
||||
|
||||
# Requirements
|
||||
1. **Atomic Actions**: Each skill should be a single, indivisible action
|
||||
2. **Complete Coverage**: Skills must cover the entire video duration with no gaps
|
||||
3. **Boundary Consistency**: The end of one skill equals the start of the next
|
||||
4. **Natural Language**: Use clear, descriptive names for each skill
|
||||
5. **Timestamps**: Use seconds (float) for all timestamps
|
||||
|
||||
|
||||
|
||||
# Output Format
|
||||
After your analysis, output ONLY valid JSON with this exact structure:
|
||||
|
||||
```json
|
||||
{{
|
||||
"skills": [
|
||||
{{"name": "skill description", "start": 0.0, "end": 1.5}},
|
||||
{{"name": "another skill", "start": 1.5, "end": 3.2}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
The first skill must start at 0.0 and the last skill must end at the video duration.
|
||||
""")
|
||||
if subtask_labels:
|
||||
labels_str = ", ".join(f'"{label}"' for label in subtask_labels)
|
||||
subtask_labels_section = (
|
||||
f'6. **Allowed labels**: Use ONLY the following skill names '
|
||||
f"(choose the best match for each segment): {labels_str}\n\n"
|
||||
)
|
||||
else:
|
||||
subtask_labels_section = ""
|
||||
return textwrap.dedent(
|
||||
SKILL_SEGMENTATION_PROMPT_TEMPLATE.format(
|
||||
goal_context=goal_context,
|
||||
subtask_labels_section=subtask_labels_section,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Qwen2-VL Implementation
|
||||
@@ -238,10 +242,14 @@ class Qwen2VL(BaseVLM):
|
||||
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
|
||||
|
||||
def segment_skills(
|
||||
self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
|
||||
self,
|
||||
video_path: Path,
|
||||
episode_duration: float,
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[Skill]:
|
||||
"""Segment video into skills using Qwen2-VL."""
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal)
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels)
|
||||
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
|
||||
|
||||
messages = [
|
||||
@@ -279,10 +287,14 @@ class Qwen2VL(BaseVLM):
|
||||
return self._parse_skills_response(response)
|
||||
|
||||
def segment_skills_batch(
|
||||
self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
|
||||
self,
|
||||
video_paths: list[Path],
|
||||
episode_durations: list[float],
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[list[Skill]]:
|
||||
"""Segment multiple videos into skills using Qwen2-VL in a batch."""
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal)
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels)
|
||||
|
||||
# Create messages for each video
|
||||
all_messages = []
|
||||
@@ -394,10 +406,14 @@ class Qwen3VL(BaseVLM):
|
||||
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
|
||||
|
||||
def segment_skills(
|
||||
self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
|
||||
self,
|
||||
video_path: Path,
|
||||
episode_duration: float,
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[Skill]:
|
||||
"""Segment video into skills using Qwen3-VL."""
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal)
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels)
|
||||
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
|
||||
|
||||
messages = [
|
||||
@@ -435,10 +451,14 @@ class Qwen3VL(BaseVLM):
|
||||
return self._parse_skills_response(response)
|
||||
|
||||
def segment_skills_batch(
|
||||
self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
|
||||
self,
|
||||
video_paths: list[Path],
|
||||
episode_durations: list[float],
|
||||
coarse_goal: str | None = None,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[list[Skill]]:
|
||||
"""Segment multiple videos into skills using Qwen3-VL in a batch."""
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal)
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels)
|
||||
|
||||
# Create messages for each video
|
||||
all_messages = []
|
||||
@@ -679,6 +699,7 @@ class SkillAnnotator:
|
||||
video_key: str,
|
||||
episodes: list[int] | None = None,
|
||||
skip_existing: bool = False,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> dict[int, EpisodeSkills]:
|
||||
"""
|
||||
Annotate all episodes in a dataset with skill labels using batched processing.
|
||||
@@ -688,6 +709,7 @@ class SkillAnnotator:
|
||||
video_key: Key for video observations (e.g., "observation.images.base")
|
||||
episodes: Specific episode indices to annotate (None = all)
|
||||
skip_existing: Skip episodes that already have skill annotations
|
||||
subtask_labels: Optional list of allowed skill labels (VLM will use only these)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping episode index to EpisodeSkills
|
||||
@@ -732,7 +754,7 @@ class SkillAnnotator:
|
||||
|
||||
try:
|
||||
batch_annotations = self._annotate_episodes_batch(
|
||||
dataset, batch_episodes, video_key, coarse_goal
|
||||
dataset, batch_episodes, video_key, coarse_goal, subtask_labels
|
||||
)
|
||||
|
||||
for ep_idx in batch_episodes:
|
||||
@@ -754,7 +776,9 @@ class SkillAnnotator:
|
||||
# Fallback: process episodes one by one
|
||||
for ep_idx in batch_episodes:
|
||||
try:
|
||||
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
|
||||
skills = self._annotate_episode(
|
||||
dataset, ep_idx, video_key, coarse_goal, subtask_labels
|
||||
)
|
||||
if skills:
|
||||
annotations[ep_idx] = EpisodeSkills(
|
||||
episode_index=ep_idx,
|
||||
@@ -778,7 +802,9 @@ class SkillAnnotator:
|
||||
for ep_idx, error_msg in list(failed_episodes.items()):
|
||||
self.console.print(f"[cyan]Retry attempt for episode {ep_idx} (previous error: {error_msg})[/cyan]")
|
||||
try:
|
||||
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
|
||||
skills = self._annotate_episode(
|
||||
dataset, ep_idx, video_key, coarse_goal, subtask_labels
|
||||
)
|
||||
if skills:
|
||||
annotations[ep_idx] = EpisodeSkills(
|
||||
episode_index=ep_idx,
|
||||
@@ -823,6 +849,7 @@ class SkillAnnotator:
|
||||
episode_indices: list[int],
|
||||
video_key: str,
|
||||
coarse_goal: str,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> dict[int, list[Skill]]:
|
||||
"""Annotate multiple episodes with skill labels in a batch."""
|
||||
# Extract all videos for this batch
|
||||
@@ -863,7 +890,9 @@ class SkillAnnotator:
|
||||
|
||||
try:
|
||||
# Run VLM skill segmentation in batch
|
||||
all_skills = self.vlm.segment_skills_batch(extracted_paths, durations, coarse_goal)
|
||||
all_skills = self.vlm.segment_skills_batch(
|
||||
extracted_paths, durations, coarse_goal, subtask_labels
|
||||
)
|
||||
|
||||
# Map results back to episode indices
|
||||
results = {}
|
||||
@@ -884,6 +913,7 @@ class SkillAnnotator:
|
||||
episode_index: int,
|
||||
video_key: str,
|
||||
coarse_goal: str,
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[Skill]:
|
||||
"""Annotate a single episode with skill labels."""
|
||||
# Get video path and timestamps for this episode
|
||||
@@ -905,7 +935,9 @@ class SkillAnnotator:
|
||||
|
||||
try:
|
||||
# Run VLM skill segmentation
|
||||
skills = self.vlm.segment_skills(extracted_path, duration, coarse_goal)
|
||||
skills = self.vlm.segment_skills(
|
||||
extracted_path, duration, coarse_goal, subtask_labels
|
||||
)
|
||||
return skills
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
@@ -1269,6 +1301,13 @@ def main():
|
||||
action="store_true",
|
||||
help="Skip episodes that already have annotations",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subtask-labels",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Optional list of allowed skill labels (VLM will use only these; space-separated)",
|
||||
)
|
||||
|
||||
# Output options
|
||||
parser.add_argument(
|
||||
@@ -1325,6 +1364,7 @@ def main():
|
||||
video_key=args.video_key,
|
||||
episodes=args.episodes,
|
||||
skip_existing=args.skip_existing,
|
||||
subtask_labels=args.subtask_labels,
|
||||
)
|
||||
|
||||
# Save annotations
|
||||
|
||||
@@ -92,3 +92,37 @@ LIBERO_KEY_JOINTS_POS = "robot_state/joints/pos"
|
||||
LIBERO_KEY_JOINTS_VEL = "robot_state/joints/vel"
|
||||
LIBERO_KEY_PIXELS_AGENTVIEW = "pixels/agentview_image"
|
||||
LIBERO_KEY_PIXELS_EYE_IN_HAND = "pixels/robot0_eye_in_hand_image"
|
||||
|
||||
# Skill segmentation prompt template for VLM-based subtask annotation
|
||||
# Placeholders: {goal_context}, {subtask_labels_section}
|
||||
SKILL_SEGMENTATION_PROMPT_TEMPLATE = """# Role
|
||||
You are a Robotics Vision System specializing in temporal action segmentation for robot manipulation demonstrations.
|
||||
|
||||
# Task
|
||||
{goal_context}Segment this robot demonstration video into short atomic manipulation skills. Each skill should:
|
||||
- Last approximately 1-3 seconds
|
||||
- Describe a clear, single action (e.g., "pick up object", "move arm left", "release gripper")
|
||||
- Have precise start and end timestamps
|
||||
|
||||
# Requirements
|
||||
1. **Atomic Actions**: Each skill should be a single, indivisible action
|
||||
2. **Complete Coverage**: Skills must cover the entire video duration with no gaps
|
||||
3. **Boundary Consistency**: The end of one skill equals the start of the next
|
||||
4. **Natural Language**: Use clear, descriptive names for each skill
|
||||
5. **Timestamps**: Use seconds (float) for all timestamps
|
||||
{subtask_labels_section}
|
||||
|
||||
# Output Format
|
||||
After your analysis, output ONLY valid JSON with this exact structure:
|
||||
|
||||
```json
|
||||
{{
|
||||
"skills": [
|
||||
{{"name": "skill description", "start": 0.0, "end": 1.5}},
|
||||
{{"name": "another skill", "start": 1.5, "end": 3.2}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
The first skill must start at 0.0 and the last skill must end at the video duration.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user