mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
clean subtask
This commit is contained in:
@@ -81,6 +81,7 @@ from rich.progress import Progress, SpinnerColumn, TextColumn
|
|||||||
|
|
||||||
from lerobot.datasets.dataset_tools import add_features
|
from lerobot.datasets.dataset_tools import add_features
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.utils.constants import SKILL_SEGMENTATION_PROMPT_TEMPLATE
|
||||||
|
|
||||||
|
|
||||||
# Skill Annotation Data Structures
|
# Skill Annotation Data Structures
|
||||||
@@ -141,7 +142,11 @@ class BaseVLM(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def segment_skills(
|
def segment_skills(
|
||||||
self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
|
self,
|
||||||
|
video_path: Path,
|
||||||
|
episode_duration: float,
|
||||||
|
coarse_goal: str | None = None,
|
||||||
|
subtask_labels: list[str] | None = None,
|
||||||
) -> list[Skill]:
|
) -> list[Skill]:
|
||||||
"""
|
"""
|
||||||
Segment a video into atomic skills.
|
Segment a video into atomic skills.
|
||||||
@@ -150,6 +155,7 @@ class BaseVLM(ABC):
|
|||||||
video_path: Path to the video file
|
video_path: Path to the video file
|
||||||
episode_duration: Total duration of the episode in seconds
|
episode_duration: Total duration of the episode in seconds
|
||||||
coarse_goal: Optional high-level task description
|
coarse_goal: Optional high-level task description
|
||||||
|
subtask_labels: Optional list of allowed skill labels to use
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of Skill objects representing atomic manipulation skills
|
List of Skill objects representing atomic manipulation skills
|
||||||
@@ -158,7 +164,11 @@ class BaseVLM(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def segment_skills_batch(
|
def segment_skills_batch(
|
||||||
self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
|
self,
|
||||||
|
video_paths: list[Path],
|
||||||
|
episode_durations: list[float],
|
||||||
|
coarse_goal: str | None = None,
|
||||||
|
subtask_labels: list[str] | None = None,
|
||||||
) -> list[list[Skill]]:
|
) -> list[list[Skill]]:
|
||||||
"""
|
"""
|
||||||
Segment multiple videos into atomic skills in a single batch.
|
Segment multiple videos into atomic skills in a single batch.
|
||||||
@@ -167,6 +177,7 @@ class BaseVLM(ABC):
|
|||||||
video_paths: List of paths to video files
|
video_paths: List of paths to video files
|
||||||
episode_durations: List of episode durations in seconds
|
episode_durations: List of episode durations in seconds
|
||||||
coarse_goal: Optional high-level task description
|
coarse_goal: Optional high-level task description
|
||||||
|
subtask_labels: Optional list of allowed skill labels to use
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of skill lists, one for each video
|
List of skill lists, one for each video
|
||||||
@@ -174,43 +185,36 @@ class BaseVLM(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def create_skill_segmentation_prompt(coarse_goal: str | None = None) -> str:
|
def create_skill_segmentation_prompt(
|
||||||
"""Create the prompt for skill segmentation."""
|
coarse_goal: str | None = None,
|
||||||
|
subtask_labels: list[str] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Create the prompt for skill segmentation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
coarse_goal: Optional high-level task description.
|
||||||
|
subtask_labels: Optional list of allowed skill/subtask labels. When provided,
|
||||||
|
the model is instructed to use only these labels (choosing the best match
|
||||||
|
for each segment).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The formatted prompt string.
|
||||||
|
"""
|
||||||
goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else ""
|
goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else ""
|
||||||
|
if subtask_labels:
|
||||||
return textwrap.dedent(f"""\
|
labels_str = ", ".join(f'"{label}"' for label in subtask_labels)
|
||||||
# Role
|
subtask_labels_section = (
|
||||||
You are a Robotics Vision System specializing in temporal action segmentation for robot manipulation demonstrations.
|
f'6. **Allowed labels**: Use ONLY the following skill names '
|
||||||
|
f"(choose the best match for each segment): {labels_str}\n\n"
|
||||||
# Task
|
)
|
||||||
{goal_context}Segment this robot demonstration video into short atomic manipulation skills. Each skill should:
|
else:
|
||||||
- Last approximately 1-3 seconds
|
subtask_labels_section = ""
|
||||||
- Describe a clear, single action (e.g., "pick up object", "move arm left", "release gripper")
|
return textwrap.dedent(
|
||||||
- Have precise start and end timestamps
|
SKILL_SEGMENTATION_PROMPT_TEMPLATE.format(
|
||||||
|
goal_context=goal_context,
|
||||||
# Requirements
|
subtask_labels_section=subtask_labels_section,
|
||||||
1. **Atomic Actions**: Each skill should be a single, indivisible action
|
)
|
||||||
2. **Complete Coverage**: Skills must cover the entire video duration with no gaps
|
)
|
||||||
3. **Boundary Consistency**: The end of one skill equals the start of the next
|
|
||||||
4. **Natural Language**: Use clear, descriptive names for each skill
|
|
||||||
5. **Timestamps**: Use seconds (float) for all timestamps
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Output Format
|
|
||||||
After your analysis, output ONLY valid JSON with this exact structure:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{{
|
|
||||||
"skills": [
|
|
||||||
{{"name": "skill description", "start": 0.0, "end": 1.5}},
|
|
||||||
{{"name": "another skill", "start": 1.5, "end": 3.2}}
|
|
||||||
]
|
|
||||||
}}
|
|
||||||
```
|
|
||||||
|
|
||||||
The first skill must start at 0.0 and the last skill must end at the video duration.
|
|
||||||
""")
|
|
||||||
|
|
||||||
|
|
||||||
# Qwen2-VL Implementation
|
# Qwen2-VL Implementation
|
||||||
@@ -238,10 +242,14 @@ class Qwen2VL(BaseVLM):
|
|||||||
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
|
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
|
||||||
|
|
||||||
def segment_skills(
|
def segment_skills(
|
||||||
self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
|
self,
|
||||||
|
video_path: Path,
|
||||||
|
episode_duration: float,
|
||||||
|
coarse_goal: str | None = None,
|
||||||
|
subtask_labels: list[str] | None = None,
|
||||||
) -> list[Skill]:
|
) -> list[Skill]:
|
||||||
"""Segment video into skills using Qwen2-VL."""
|
"""Segment video into skills using Qwen2-VL."""
|
||||||
prompt = create_skill_segmentation_prompt(coarse_goal)
|
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels)
|
||||||
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
|
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
@@ -279,10 +287,14 @@ class Qwen2VL(BaseVLM):
|
|||||||
return self._parse_skills_response(response)
|
return self._parse_skills_response(response)
|
||||||
|
|
||||||
def segment_skills_batch(
|
def segment_skills_batch(
|
||||||
self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
|
self,
|
||||||
|
video_paths: list[Path],
|
||||||
|
episode_durations: list[float],
|
||||||
|
coarse_goal: str | None = None,
|
||||||
|
subtask_labels: list[str] | None = None,
|
||||||
) -> list[list[Skill]]:
|
) -> list[list[Skill]]:
|
||||||
"""Segment multiple videos into skills using Qwen2-VL in a batch."""
|
"""Segment multiple videos into skills using Qwen2-VL in a batch."""
|
||||||
prompt = create_skill_segmentation_prompt(coarse_goal)
|
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels)
|
||||||
|
|
||||||
# Create messages for each video
|
# Create messages for each video
|
||||||
all_messages = []
|
all_messages = []
|
||||||
@@ -394,10 +406,14 @@ class Qwen3VL(BaseVLM):
|
|||||||
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
|
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
|
||||||
|
|
||||||
def segment_skills(
|
def segment_skills(
|
||||||
self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
|
self,
|
||||||
|
video_path: Path,
|
||||||
|
episode_duration: float,
|
||||||
|
coarse_goal: str | None = None,
|
||||||
|
subtask_labels: list[str] | None = None,
|
||||||
) -> list[Skill]:
|
) -> list[Skill]:
|
||||||
"""Segment video into skills using Qwen3-VL."""
|
"""Segment video into skills using Qwen3-VL."""
|
||||||
prompt = create_skill_segmentation_prompt(coarse_goal)
|
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels)
|
||||||
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
|
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
@@ -435,10 +451,14 @@ class Qwen3VL(BaseVLM):
|
|||||||
return self._parse_skills_response(response)
|
return self._parse_skills_response(response)
|
||||||
|
|
||||||
def segment_skills_batch(
|
def segment_skills_batch(
|
||||||
self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
|
self,
|
||||||
|
video_paths: list[Path],
|
||||||
|
episode_durations: list[float],
|
||||||
|
coarse_goal: str | None = None,
|
||||||
|
subtask_labels: list[str] | None = None,
|
||||||
) -> list[list[Skill]]:
|
) -> list[list[Skill]]:
|
||||||
"""Segment multiple videos into skills using Qwen3-VL in a batch."""
|
"""Segment multiple videos into skills using Qwen3-VL in a batch."""
|
||||||
prompt = create_skill_segmentation_prompt(coarse_goal)
|
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels)
|
||||||
|
|
||||||
# Create messages for each video
|
# Create messages for each video
|
||||||
all_messages = []
|
all_messages = []
|
||||||
@@ -679,6 +699,7 @@ class SkillAnnotator:
|
|||||||
video_key: str,
|
video_key: str,
|
||||||
episodes: list[int] | None = None,
|
episodes: list[int] | None = None,
|
||||||
skip_existing: bool = False,
|
skip_existing: bool = False,
|
||||||
|
subtask_labels: list[str] | None = None,
|
||||||
) -> dict[int, EpisodeSkills]:
|
) -> dict[int, EpisodeSkills]:
|
||||||
"""
|
"""
|
||||||
Annotate all episodes in a dataset with skill labels using batched processing.
|
Annotate all episodes in a dataset with skill labels using batched processing.
|
||||||
@@ -688,6 +709,7 @@ class SkillAnnotator:
|
|||||||
video_key: Key for video observations (e.g., "observation.images.base")
|
video_key: Key for video observations (e.g., "observation.images.base")
|
||||||
episodes: Specific episode indices to annotate (None = all)
|
episodes: Specific episode indices to annotate (None = all)
|
||||||
skip_existing: Skip episodes that already have skill annotations
|
skip_existing: Skip episodes that already have skill annotations
|
||||||
|
subtask_labels: Optional list of allowed skill labels (VLM will use only these)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary mapping episode index to EpisodeSkills
|
Dictionary mapping episode index to EpisodeSkills
|
||||||
@@ -732,7 +754,7 @@ class SkillAnnotator:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
batch_annotations = self._annotate_episodes_batch(
|
batch_annotations = self._annotate_episodes_batch(
|
||||||
dataset, batch_episodes, video_key, coarse_goal
|
dataset, batch_episodes, video_key, coarse_goal, subtask_labels
|
||||||
)
|
)
|
||||||
|
|
||||||
for ep_idx in batch_episodes:
|
for ep_idx in batch_episodes:
|
||||||
@@ -754,7 +776,9 @@ class SkillAnnotator:
|
|||||||
# Fallback: process episodes one by one
|
# Fallback: process episodes one by one
|
||||||
for ep_idx in batch_episodes:
|
for ep_idx in batch_episodes:
|
||||||
try:
|
try:
|
||||||
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
|
skills = self._annotate_episode(
|
||||||
|
dataset, ep_idx, video_key, coarse_goal, subtask_labels
|
||||||
|
)
|
||||||
if skills:
|
if skills:
|
||||||
annotations[ep_idx] = EpisodeSkills(
|
annotations[ep_idx] = EpisodeSkills(
|
||||||
episode_index=ep_idx,
|
episode_index=ep_idx,
|
||||||
@@ -778,7 +802,9 @@ class SkillAnnotator:
|
|||||||
for ep_idx, error_msg in list(failed_episodes.items()):
|
for ep_idx, error_msg in list(failed_episodes.items()):
|
||||||
self.console.print(f"[cyan]Retry attempt for episode {ep_idx} (previous error: {error_msg})[/cyan]")
|
self.console.print(f"[cyan]Retry attempt for episode {ep_idx} (previous error: {error_msg})[/cyan]")
|
||||||
try:
|
try:
|
||||||
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
|
skills = self._annotate_episode(
|
||||||
|
dataset, ep_idx, video_key, coarse_goal, subtask_labels
|
||||||
|
)
|
||||||
if skills:
|
if skills:
|
||||||
annotations[ep_idx] = EpisodeSkills(
|
annotations[ep_idx] = EpisodeSkills(
|
||||||
episode_index=ep_idx,
|
episode_index=ep_idx,
|
||||||
@@ -823,6 +849,7 @@ class SkillAnnotator:
|
|||||||
episode_indices: list[int],
|
episode_indices: list[int],
|
||||||
video_key: str,
|
video_key: str,
|
||||||
coarse_goal: str,
|
coarse_goal: str,
|
||||||
|
subtask_labels: list[str] | None = None,
|
||||||
) -> dict[int, list[Skill]]:
|
) -> dict[int, list[Skill]]:
|
||||||
"""Annotate multiple episodes with skill labels in a batch."""
|
"""Annotate multiple episodes with skill labels in a batch."""
|
||||||
# Extract all videos for this batch
|
# Extract all videos for this batch
|
||||||
@@ -863,7 +890,9 @@ class SkillAnnotator:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Run VLM skill segmentation in batch
|
# Run VLM skill segmentation in batch
|
||||||
all_skills = self.vlm.segment_skills_batch(extracted_paths, durations, coarse_goal)
|
all_skills = self.vlm.segment_skills_batch(
|
||||||
|
extracted_paths, durations, coarse_goal, subtask_labels
|
||||||
|
)
|
||||||
|
|
||||||
# Map results back to episode indices
|
# Map results back to episode indices
|
||||||
results = {}
|
results = {}
|
||||||
@@ -884,6 +913,7 @@ class SkillAnnotator:
|
|||||||
episode_index: int,
|
episode_index: int,
|
||||||
video_key: str,
|
video_key: str,
|
||||||
coarse_goal: str,
|
coarse_goal: str,
|
||||||
|
subtask_labels: list[str] | None = None,
|
||||||
) -> list[Skill]:
|
) -> list[Skill]:
|
||||||
"""Annotate a single episode with skill labels."""
|
"""Annotate a single episode with skill labels."""
|
||||||
# Get video path and timestamps for this episode
|
# Get video path and timestamps for this episode
|
||||||
@@ -905,7 +935,9 @@ class SkillAnnotator:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Run VLM skill segmentation
|
# Run VLM skill segmentation
|
||||||
skills = self.vlm.segment_skills(extracted_path, duration, coarse_goal)
|
skills = self.vlm.segment_skills(
|
||||||
|
extracted_path, duration, coarse_goal, subtask_labels
|
||||||
|
)
|
||||||
return skills
|
return skills
|
||||||
finally:
|
finally:
|
||||||
# Clean up temporary file
|
# Clean up temporary file
|
||||||
@@ -1269,6 +1301,13 @@ def main():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Skip episodes that already have annotations",
|
help="Skip episodes that already have annotations",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--subtask-labels",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
default=None,
|
||||||
|
help="Optional list of allowed skill labels (VLM will use only these; space-separated)",
|
||||||
|
)
|
||||||
|
|
||||||
# Output options
|
# Output options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -1325,6 +1364,7 @@ def main():
|
|||||||
video_key=args.video_key,
|
video_key=args.video_key,
|
||||||
episodes=args.episodes,
|
episodes=args.episodes,
|
||||||
skip_existing=args.skip_existing,
|
skip_existing=args.skip_existing,
|
||||||
|
subtask_labels=args.subtask_labels,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save annotations
|
# Save annotations
|
||||||
|
|||||||
@@ -92,3 +92,37 @@ LIBERO_KEY_JOINTS_POS = "robot_state/joints/pos"
|
|||||||
LIBERO_KEY_JOINTS_VEL = "robot_state/joints/vel"
|
LIBERO_KEY_JOINTS_VEL = "robot_state/joints/vel"
|
||||||
LIBERO_KEY_PIXELS_AGENTVIEW = "pixels/agentview_image"
|
LIBERO_KEY_PIXELS_AGENTVIEW = "pixels/agentview_image"
|
||||||
LIBERO_KEY_PIXELS_EYE_IN_HAND = "pixels/robot0_eye_in_hand_image"
|
LIBERO_KEY_PIXELS_EYE_IN_HAND = "pixels/robot0_eye_in_hand_image"
|
||||||
|
|
||||||
|
# Skill segmentation prompt template for VLM-based subtask annotation
|
||||||
|
# Placeholders: {goal_context}, {subtask_labels_section}
|
||||||
|
SKILL_SEGMENTATION_PROMPT_TEMPLATE = """# Role
|
||||||
|
You are a Robotics Vision System specializing in temporal action segmentation for robot manipulation demonstrations.
|
||||||
|
|
||||||
|
# Task
|
||||||
|
{goal_context}Segment this robot demonstration video into short atomic manipulation skills. Each skill should:
|
||||||
|
- Last approximately 1-3 seconds
|
||||||
|
- Describe a clear, single action (e.g., "pick up object", "move arm left", "release gripper")
|
||||||
|
- Have precise start and end timestamps
|
||||||
|
|
||||||
|
# Requirements
|
||||||
|
1. **Atomic Actions**: Each skill should be a single, indivisible action
|
||||||
|
2. **Complete Coverage**: Skills must cover the entire video duration with no gaps
|
||||||
|
3. **Boundary Consistency**: The end of one skill equals the start of the next
|
||||||
|
4. **Natural Language**: Use clear, descriptive names for each skill
|
||||||
|
5. **Timestamps**: Use seconds (float) for all timestamps
|
||||||
|
{subtask_labels_section}
|
||||||
|
|
||||||
|
# Output Format
|
||||||
|
After your analysis, output ONLY valid JSON with this exact structure:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{{
|
||||||
|
"skills": [
|
||||||
|
{{"name": "skill description", "start": 0.0, "end": 1.5}},
|
||||||
|
{{"name": "another skill", "start": 1.5, "end": 3.2}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
```
|
||||||
|
|
||||||
|
The first skill must start at 0.0 and the last skill must end at the video duration.
|
||||||
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user