mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 06:29:47 +00:00
simplify and cleanup code and move compute_temporal_proportions to utils
This commit is contained in:
@@ -72,13 +72,11 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from qwen_vl_utils import process_vision_info
|
from qwen_vl_utils import process_vision_info
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
@@ -86,24 +84,8 @@ from rich.tree import Tree
|
|||||||
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
|
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
|
||||||
|
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.policies.sarm.sarm_utils import compute_temporal_proportions
|
||||||
# Pydantic Models for SARM-style Annotation
|
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
|
||||||
class Timestamp(BaseModel):
|
|
||||||
"""Timestamp in MM:SS or SS format"""
|
|
||||||
start: str = Field(description="Start timestamp (MM:SS or just seconds)")
|
|
||||||
end: str = Field(description="End timestamp (MM:SS or just seconds)")
|
|
||||||
|
|
||||||
|
|
||||||
class Subtask(BaseModel):
|
|
||||||
"""Individual subtask/stage - must use EXACT names from provided list"""
|
|
||||||
name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
|
|
||||||
timestamps: Timestamp
|
|
||||||
|
|
||||||
|
|
||||||
class SubtaskAnnotation(BaseModel):
|
|
||||||
"""Complete annotation for a robot manipulation episode"""
|
|
||||||
subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
|
|
||||||
|
|
||||||
|
|
||||||
def create_sarm_prompt(subtask_list: list[str]) -> str:
|
def create_sarm_prompt(subtask_list: list[str]) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -769,59 +751,6 @@ def worker_process_episodes(
|
|||||||
return annotations
|
return annotations
|
||||||
|
|
||||||
|
|
||||||
def compute_temporal_proportions(annotations: dict[int, SubtaskAnnotation], fps: int = 30) -> dict[str, float]:
|
|
||||||
"""
|
|
||||||
Compute average temporal proportion for each subtask across all episodes.
|
|
||||||
This is the key insight from SARM - use semantic subtasks instead of frame indices.
|
|
||||||
"""
|
|
||||||
# Collect all proportions per subtask
|
|
||||||
subtask_proportions = {}
|
|
||||||
|
|
||||||
for annotation in annotations.values():
|
|
||||||
# Calculate total episode duration
|
|
||||||
total_duration = 0
|
|
||||||
durations = {}
|
|
||||||
|
|
||||||
for subtask in annotation.subtasks:
|
|
||||||
# Parse timestamps
|
|
||||||
start_parts = subtask.timestamps.start.split(":")
|
|
||||||
end_parts = subtask.timestamps.end.split(":")
|
|
||||||
|
|
||||||
if len(start_parts) == 2:
|
|
||||||
start_seconds = int(start_parts[0]) * 60 + int(start_parts[1])
|
|
||||||
else:
|
|
||||||
start_seconds = int(start_parts[0])
|
|
||||||
|
|
||||||
if len(end_parts) == 2:
|
|
||||||
end_seconds = int(end_parts[0]) * 60 + int(end_parts[1])
|
|
||||||
else:
|
|
||||||
end_seconds = int(end_parts[0])
|
|
||||||
|
|
||||||
duration = end_seconds - start_seconds
|
|
||||||
durations[subtask.name] = duration
|
|
||||||
total_duration += duration
|
|
||||||
|
|
||||||
# Calculate proportions for this episode
|
|
||||||
if total_duration > 0:
|
|
||||||
for name, duration in durations.items():
|
|
||||||
if name not in subtask_proportions:
|
|
||||||
subtask_proportions[name] = []
|
|
||||||
subtask_proportions[name].append(duration / total_duration)
|
|
||||||
|
|
||||||
# Average across episodes
|
|
||||||
avg_proportions = {
|
|
||||||
name: sum(props) / len(props)
|
|
||||||
for name, props in subtask_proportions.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
# Normalize to sum to 1.0
|
|
||||||
total = sum(avg_proportions.values())
|
|
||||||
if total > 0:
|
|
||||||
avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
|
|
||||||
|
|
||||||
return avg_proportions
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="SARM-style subtask annotation using local GPU (Qwen3-VL)",
|
description="SARM-style subtask annotation using local GPU (Qwen3-VL)",
|
||||||
@@ -1185,4 +1114,3 @@ Performance Tips:
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from transformers import CLIPModel, CLIPProcessor
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||||
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim
|
from lerobot.policies.sarm.sarm_utils import compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
|
||||||
class SARMTransformer(nn.Module):
|
class SARMTransformer(nn.Module):
|
||||||
|
|||||||
@@ -14,8 +14,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -23,6 +21,7 @@ from PIL import Image
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from transformers import CLIPModel, CLIPProcessor
|
from transformers import CLIPModel, CLIPProcessor
|
||||||
|
|
||||||
|
from lerobot.processor.core import TransitionKey
|
||||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||||
from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
|
from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
@@ -119,7 +118,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
frame_indices: np.ndarray,
|
frame_indices: np.ndarray,
|
||||||
episode_indices: np.ndarray,
|
episode_indices: np.ndarray,
|
||||||
num_frames: int,
|
num_frames: int,
|
||||||
is_batch: bool,
|
|
||||||
) -> tuple[list | torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[list | torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""Compute episode metadata for all samples.
|
"""Compute episode metadata for all samples.
|
||||||
|
|
||||||
@@ -140,10 +138,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
absolute_indices_list.append(abs_indices)
|
absolute_indices_list.append(abs_indices)
|
||||||
remaining_lengths.append(ep_end - abs_indices[0].item())
|
remaining_lengths.append(ep_end - abs_indices[0].item())
|
||||||
|
|
||||||
if is_batch:
|
|
||||||
return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths)
|
return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths)
|
||||||
else:
|
|
||||||
return absolute_indices_list[0], remaining_lengths[0], episode_lengths[0]
|
|
||||||
|
|
||||||
def _compute_stage_and_progress_for_frame(
|
def _compute_stage_and_progress_for_frame(
|
||||||
self,
|
self,
|
||||||
@@ -189,7 +184,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
cumulative_progress = compute_cumulative_progress_batch(
|
cumulative_progress = compute_cumulative_progress_batch(
|
||||||
tau, stage_idx, temporal_proportions_list
|
tau, stage_idx, temporal_proportions_list
|
||||||
)
|
)
|
||||||
|
|
||||||
return stage_idx, cumulative_progress
|
return stage_idx, cumulative_progress
|
||||||
|
|
||||||
# No matching subtask found
|
# No matching subtask found
|
||||||
@@ -288,15 +282,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
if self.temporal_proportions is None or episode_index is None:
|
if self.temporal_proportions is None or episode_index is None:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
|
|
||||||
|
|
||||||
# Normalize inputs to numpy arrays
|
# Normalize inputs to numpy arrays
|
||||||
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
|
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
|
||||||
episode_indices = self._get_episode_indices(frame_indices, episode_index)
|
episode_indices = self._get_episode_indices(frame_indices, episode_index)
|
||||||
|
|
||||||
# Determine sequence length
|
# Determine sequence length
|
||||||
if video_features is not None and video_features.dim() >= 2:
|
if video_features is not None and video_features.dim() >= 2:
|
||||||
seq_len = video_features.shape[1] if is_batch else video_features.shape[0]
|
seq_len = video_features.shape[1]
|
||||||
else:
|
else:
|
||||||
seq_len = 1
|
seq_len = 1
|
||||||
|
|
||||||
@@ -315,24 +307,15 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
all_stage_labels.append(result[0])
|
all_stage_labels.append(result[0])
|
||||||
all_progress_targets.append(result[1])
|
all_progress_targets.append(result[1])
|
||||||
|
|
||||||
if is_batch:
|
|
||||||
return torch.stack(all_stage_labels, dim=0), torch.stack(all_progress_targets, dim=0)
|
return torch.stack(all_stage_labels, dim=0), torch.stack(all_progress_targets, dim=0)
|
||||||
return all_stage_labels[0], all_progress_targets[0]
|
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
"""Encode images, text, and normalize states in the transition."""
|
"""Encode images, text, and normalize states in the transition."""
|
||||||
from lerobot.processor.core import TransitionKey
|
|
||||||
|
|
||||||
new_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
|
new_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
|
||||||
|
|
||||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||||
if not isinstance(observation, dict):
|
|
||||||
raise ValueError("Observation must be a dictionary")
|
|
||||||
|
|
||||||
# Encode images with CLIP
|
|
||||||
image = observation.get(self.image_key)
|
image = observation.get(self.image_key)
|
||||||
if image is None:
|
|
||||||
raise ValueError(f"Image not found in observation for key: {self.image_key}")
|
|
||||||
|
|
||||||
if isinstance(image, torch.Tensor):
|
if isinstance(image, torch.Tensor):
|
||||||
image = image.cpu().numpy()
|
image = image.cpu().numpy()
|
||||||
@@ -342,10 +325,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
# Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep)
|
# Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep)
|
||||||
state_key = self.config.state_key
|
state_key = self.config.state_key
|
||||||
state_data = observation.get(state_key)
|
state_data = observation.get(state_key)
|
||||||
if state_data is None:
|
|
||||||
state_data = observation.get("state") or observation.get("observation.state")
|
|
||||||
if state_data is None:
|
|
||||||
raise ValueError(f"State data not found in observation (expected '{state_key}', 'state', or 'observation.state')")
|
|
||||||
|
|
||||||
if isinstance(state_data, torch.Tensor):
|
if isinstance(state_data, torch.Tensor):
|
||||||
state_tensor = state_data.float()
|
state_tensor = state_data.float()
|
||||||
@@ -355,8 +334,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
|
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
|
||||||
|
|
||||||
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||||
if not isinstance(comp_data, dict):
|
|
||||||
raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
|
|
||||||
|
|
||||||
# Get task description from dataset (complementary_data["task"])
|
# Get task description from dataset (complementary_data["task"])
|
||||||
task = comp_data.get('task')
|
task = comp_data.get('task')
|
||||||
@@ -378,18 +355,17 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
|
|
||||||
# Compute episode metadata if dataset_meta is available
|
# Compute episode metadata if dataset_meta is available
|
||||||
if self.dataset_meta is not None:
|
if self.dataset_meta is not None:
|
||||||
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
|
|
||||||
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
|
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
|
||||||
episode_indices = self._get_episode_indices(frame_indices, episode_index)
|
episode_indices = self._get_episode_indices(frame_indices, episode_index)
|
||||||
|
|
||||||
# Determine number of frames from video features
|
# Determine number of frames from video features
|
||||||
if video_features.dim() >= 2:
|
if video_features.dim() >= 2:
|
||||||
num_frames = video_features.shape[1] if is_batch else video_features.shape[0]
|
num_frames = video_features.shape[1]
|
||||||
else:
|
else:
|
||||||
num_frames = 1
|
num_frames = 1
|
||||||
|
|
||||||
abs_indices, remaining, ep_lengths = self._compute_episode_metadata(
|
abs_indices, remaining, ep_lengths = self._compute_episode_metadata(
|
||||||
frame_indices, episode_indices, num_frames, is_batch
|
frame_indices, episode_indices, num_frames
|
||||||
)
|
)
|
||||||
observation['absolute_frame_indices'] = abs_indices
|
observation['absolute_frame_indices'] = abs_indices
|
||||||
observation['remaining_length'] = remaining
|
observation['remaining_length'] = remaining
|
||||||
@@ -412,24 +388,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
"""Encode a batch of images using CLIP.
|
"""Encode a batch of images using CLIP.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
images: Batched images with shape:
|
images: Batched images with shape: (B, T, C, H, W)
|
||||||
- (B, C, H, W) for single frames, or
|
|
||||||
- (B, T, C, H, W) for temporal sequences
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Encoded feature vectors with shape (B, 512) or (B, T, 512)
|
Encoded feature vectors with shape (B, T, 512)
|
||||||
"""
|
"""
|
||||||
# Check if we have temporal dimension
|
|
||||||
has_temporal = len(images.shape) == 5
|
|
||||||
|
|
||||||
if has_temporal: # Shape: (B, T, C, H, W)
|
|
||||||
batch_size, seq_length = images.shape[0], images.shape[1]
|
batch_size, seq_length = images.shape[0], images.shape[1]
|
||||||
images = images.reshape(batch_size * seq_length, *images.shape[2:])
|
images = images.reshape(batch_size * seq_length, *images.shape[2:])
|
||||||
elif len(images.shape) == 4: # Shape: (B, C, H, W)
|
|
||||||
batch_size = images.shape[0]
|
|
||||||
seq_length = 1
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Expected 4D (B, C, H, W) or 5D (B, T, C, H, W) input, got shape {images.shape}")
|
|
||||||
|
|
||||||
# Convert to list of PIL images
|
# Convert to list of PIL images
|
||||||
num_frames = images.shape[0]
|
num_frames = images.shape[0]
|
||||||
@@ -470,8 +436,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
# Concatenate all embeddings
|
# Concatenate all embeddings
|
||||||
all_embeddings = torch.cat(all_embeddings) # (B*T, 512)
|
all_embeddings = torch.cat(all_embeddings) # (B*T, 512)
|
||||||
|
|
||||||
# Reshape back if temporal
|
# Reshape back
|
||||||
if has_temporal:
|
|
||||||
all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512)
|
all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512)
|
||||||
|
|
||||||
return all_embeddings
|
return all_embeddings
|
||||||
@@ -495,10 +460,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
# Get text features from CLIP
|
# Get text features from CLIP
|
||||||
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
|
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
|
||||||
|
|
||||||
# Handle single text case
|
|
||||||
if text_embedding.dim() == 1:
|
|
||||||
text_embedding = text_embedding.unsqueeze(0)
|
|
||||||
|
|
||||||
# Replicate for batch (B, 512)
|
# Replicate for batch (B, 512)
|
||||||
text_embedding = text_embedding.expand(batch_size, -1)
|
text_embedding = text_embedding.expand(batch_size, -1)
|
||||||
|
|
||||||
|
|||||||
@@ -14,25 +14,31 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
|
||||||
Utility functions for SARM progress label computation.
|
|
||||||
|
|
||||||
Implements formulas from the SARM paper:
|
|
||||||
- Formula (1): Compute dataset-level temporal proportions (priors) ᾱ_k
|
|
||||||
- Formula (2): Compute normalized progress targets y_t = P_{k-1} + ᾱ_k × τ_t
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import Sequence
|
from typing import Sequence, Any
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
# Pydantic Models for SARM-style Annotation
|
||||||
|
class Timestamp(BaseModel):
|
||||||
|
"""Timestamp in MM:SS or SS format"""
|
||||||
|
start: str = Field(description="Start timestamp (MM:SS or just seconds)")
|
||||||
|
end: str = Field(description="End timestamp (MM:SS or just seconds)")
|
||||||
|
|
||||||
|
|
||||||
def compute_priors(
|
class Subtask(BaseModel):
|
||||||
subtask_durations_per_trajectory: dict[str, list[float]],
|
"""Individual subtask/stage - must use EXACT names from provided list"""
|
||||||
trajectory_lengths: dict[str, list[float]],
|
name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
|
||||||
subtask_names: list[str],
|
timestamps: Timestamp
|
||||||
) -> dict[str, float]:
|
|
||||||
|
|
||||||
|
class SubtaskAnnotation(BaseModel):
|
||||||
|
"""Complete annotation for a robot manipulation episode"""
|
||||||
|
subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
|
||||||
|
|
||||||
|
|
||||||
|
def compute_temporal_proportions(annotations: dict[int, Any], fps: int = 30) -> dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Compute dataset-level temporal proportions (priors) for each subtask.
|
Compute dataset-level temporal proportions (priors) for each subtask.
|
||||||
|
|
||||||
@@ -40,61 +46,64 @@ def compute_priors(
|
|||||||
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
||||||
|
|
||||||
where:
|
where:
|
||||||
- M is the number of trajectories
|
- M is the number of trajectories (episodes)
|
||||||
- L_{i,k} is the length of subtask k in trajectory i
|
- L_{i,k} is the duration of subtask k in trajectory i
|
||||||
- T_i is the total length of trajectory i
|
- T_i is the total duration of trajectory i
|
||||||
|
|
||||||
This averages the PROPORTION of each subtask within each trajectory,
|
This averages the PROPORTION of each subtask within each trajectory,
|
||||||
giving equal weight to all trajectories regardless of their absolute length.
|
giving equal weight to all trajectories regardless of their absolute length.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
subtask_durations_per_trajectory: Dict mapping subtask name to list of
|
annotations: Dict mapping episode index to SubtaskAnnotation object.
|
||||||
(duration, trajectory_length) tuples for each occurrence
|
Each annotation has a .subtasks list where each subtask has:
|
||||||
trajectory_lengths: Dict mapping subtask name to list of trajectory lengths
|
- .name: subtask name
|
||||||
for each occurrence of that subtask
|
- .timestamps.start: start time as "MM:SS" string
|
||||||
subtask_names: Ordered list of subtask names
|
- .timestamps.end: end time as "MM:SS" string
|
||||||
|
fps: Frames per second (unused, kept for API compatibility)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict mapping subtask name to its temporal proportion (ᾱ_k)
|
Dict mapping subtask name to its temporal proportion (ᾱ_k).
|
||||||
|
Proportions are normalized to sum to 1.0.
|
||||||
"""
|
"""
|
||||||
if not subtask_names:
|
subtask_proportions: dict[str, list[float]] = {}
|
||||||
raise ValueError("subtask_names cannot be empty")
|
|
||||||
|
|
||||||
# Compute proportion per occurrence: L_{i,k} / T_i
|
for annotation in annotations.values():
|
||||||
subtask_proportions = {}
|
total_duration = 0
|
||||||
for name in subtask_names:
|
durations: dict[str, int] = {}
|
||||||
if name in subtask_durations_per_trajectory and name in trajectory_lengths:
|
|
||||||
durations = subtask_durations_per_trajectory[name]
|
|
||||||
traj_lengths = trajectory_lengths[name]
|
|
||||||
|
|
||||||
if len(durations) != len(traj_lengths):
|
for subtask in annotation.subtasks:
|
||||||
raise ValueError(
|
start_parts = subtask.timestamps.start.split(":")
|
||||||
f"Mismatch in lengths for subtask '{name}': "
|
end_parts = subtask.timestamps.end.split(":")
|
||||||
f"{len(durations)} durations vs {len(traj_lengths)} trajectory lengths"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compute L_{i,k} / T_i for each occurrence
|
start_seconds = int(start_parts[0]) * 60 + int(start_parts[1]) if len(start_parts) == 2 else int(start_parts[0])
|
||||||
proportions = []
|
end_seconds = int(end_parts[0]) * 60 + int(end_parts[1]) if len(end_parts) == 2 else int(end_parts[0])
|
||||||
for duration, traj_len in zip(durations, traj_lengths):
|
|
||||||
if traj_len > 0:
|
|
||||||
proportions.append(duration / traj_len)
|
|
||||||
|
|
||||||
# Average across all occurrences: (1/M) × Σ_i (L_{i,k} / T_i)
|
duration = end_seconds - start_seconds
|
||||||
subtask_proportions[name] = np.mean(proportions) if proportions else 0.0
|
durations[subtask.name] = duration
|
||||||
else:
|
total_duration += duration
|
||||||
subtask_proportions[name] = 0.0
|
|
||||||
|
|
||||||
# Normalize to ensure sum = 1 (handles floating point errors and missing subtasks)
|
# Calculate L_{i,k} / T_i for each subtask in this trajectory
|
||||||
total = sum(subtask_proportions.values())
|
if total_duration > 0:
|
||||||
if total > 0:
|
for name, duration in durations.items():
|
||||||
subtask_proportions = {
|
if name not in subtask_proportions:
|
||||||
name: prop / total for name, prop in subtask_proportions.items()
|
subtask_proportions[name] = []
|
||||||
|
subtask_proportions[name].append(duration / total_duration)
|
||||||
|
|
||||||
|
if not subtask_proportions:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Average across trajectories: (1/M) × Σ_i (L_{i,k} / T_i)
|
||||||
|
avg_proportions = {
|
||||||
|
name: sum(props) / len(props)
|
||||||
|
for name, props in subtask_proportions.items()
|
||||||
}
|
}
|
||||||
else:
|
|
||||||
raise ValueError("Cannot compute temporal proportions: all proportions are zero. "
|
|
||||||
"Check that your dataset has valid subtask annotations with start/end times.")
|
|
||||||
|
|
||||||
return subtask_proportions
|
# Normalize to ensure sum = 1
|
||||||
|
total = sum(avg_proportions.values())
|
||||||
|
if total > 0:
|
||||||
|
avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
|
||||||
|
|
||||||
|
return avg_proportions
|
||||||
|
|
||||||
|
|
||||||
def compute_tau(
|
def compute_tau(
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
Tests for SARM utility functions.
|
Tests for SARM utility functions.
|
||||||
|
|
||||||
Tests the implementation of SARM paper formulas:
|
Tests the implementation of SARM paper formulas:
|
||||||
- Formula (1): compute_priors - dataset-level temporal proportions
|
- Formula (1): compute_temporal_proportions - dataset-level temporal proportions
|
||||||
- Formula (2): compute_tau, compute_cumulative_progress - progress labels
|
- Formula (2): compute_tau, compute_cumulative_progress - progress labels
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -26,15 +26,31 @@ import pytest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
|
||||||
from lerobot.policies.sarm.sarm_utils import (
|
from lerobot.policies.sarm.sarm_utils import (
|
||||||
compute_priors,
|
compute_temporal_proportions,
|
||||||
compute_tau,
|
compute_tau,
|
||||||
compute_cumulative_progress_batch,
|
compute_cumulative_progress_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def make_annotation(subtasks: list[tuple[str, int, int]]) -> SubtaskAnnotation:
|
||||||
|
"""Helper to create SubtaskAnnotation from list of (name, start_sec, end_sec)."""
|
||||||
|
return SubtaskAnnotation(
|
||||||
|
subtasks=[
|
||||||
|
Subtask(
|
||||||
|
name=name,
|
||||||
|
timestamps=Timestamp(
|
||||||
|
start=f"{start // 60:02d}:{start % 60:02d}",
|
||||||
|
end=f"{end // 60:02d}:{end % 60:02d}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for name, start, end in subtasks
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
class TestComputePriors:
|
|
||||||
"""Tests for compute_priors (SARM Paper Formula 1).
|
class TestComputeTemporalProportions:
|
||||||
|
"""Tests for compute_temporal_proportions (SARM Paper Formula 1).
|
||||||
|
|
||||||
Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
||||||
|
|
||||||
@@ -45,42 +61,34 @@ class TestComputePriors:
|
|||||||
def test_basic_two_trajectories_equal_proportions(self):
|
def test_basic_two_trajectories_equal_proportions(self):
|
||||||
"""Test with two trajectories that have equal proportions."""
|
"""Test with two trajectories that have equal proportions."""
|
||||||
# Both trajectories: subtask1 = 50%, subtask2 = 50%
|
# Both trajectories: subtask1 = 50%, subtask2 = 50%
|
||||||
subtask_durations = {
|
# Traj 1: T=100s, subtask1=50s, subtask2=50s
|
||||||
'subtask1': [50, 100], # durations
|
# Traj 2: T=200s, subtask1=100s, subtask2=100s
|
||||||
'subtask2': [50, 100],
|
annotations = {
|
||||||
|
0: make_annotation([('subtask1', 0, 50), ('subtask2', 50, 100)]),
|
||||||
|
1: make_annotation([('subtask1', 0, 100), ('subtask2', 100, 200)]),
|
||||||
}
|
}
|
||||||
trajectory_lengths = {
|
|
||||||
'subtask1': [100, 200],
|
|
||||||
'subtask2': [100, 200],
|
|
||||||
}
|
|
||||||
subtask_names = ['subtask1', 'subtask2']
|
|
||||||
|
|
||||||
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
|
result = compute_temporal_proportions(annotations)
|
||||||
|
|
||||||
# Both should be 0.5
|
# Both should be 0.5
|
||||||
assert abs(result['subtask1'] - 0.5) < 1e-6
|
assert abs(result['subtask1'] - 0.5) < 1e-6
|
||||||
assert abs(result['subtask2'] - 0.5) < 1e-6
|
assert abs(result['subtask2'] - 0.5) < 1e-6
|
||||||
|
|
||||||
def test_paper_example_different_from_avg_durations(self):
|
def test_paper_example_different_from_avg_durations(self):
|
||||||
"""Test that compute_priors differs from naive average duration approach.
|
"""Test that compute_temporal_proportions differs from naive average duration approach.
|
||||||
|
|
||||||
This is the key test showing the difference between:
|
This is the key test showing the difference between:
|
||||||
- Paper formula: average of (L_i,k / T_i)
|
- Paper formula: average of (L_i,k / T_i)
|
||||||
- Naive approach: mean(L_i,k) / sum(mean(L_i,j))
|
- Naive approach: mean(L_i,k) / sum(mean(L_i,j))
|
||||||
"""
|
"""
|
||||||
# Episode 1: T=100, subtask1=80, subtask2=20 (proportions: 0.8, 0.2)
|
# Episode 1: T=100s, subtask1=80s, subtask2=20s (proportions: 0.8, 0.2)
|
||||||
# Episode 2: T=200, subtask1=40, subtask2=160 (proportions: 0.2, 0.8)
|
# Episode 2: T=200s, subtask1=40s, subtask2=160s (proportions: 0.2, 0.8)
|
||||||
subtask_durations = {
|
annotations = {
|
||||||
'subtask1': [80, 40],
|
0: make_annotation([('subtask1', 0, 80), ('subtask2', 80, 100)]),
|
||||||
'subtask2': [20, 160],
|
1: make_annotation([('subtask1', 0, 40), ('subtask2', 40, 200)]),
|
||||||
}
|
}
|
||||||
trajectory_lengths = {
|
|
||||||
'subtask1': [100, 200],
|
|
||||||
'subtask2': [100, 200],
|
|
||||||
}
|
|
||||||
subtask_names = ['subtask1', 'subtask2']
|
|
||||||
|
|
||||||
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
|
result = compute_temporal_proportions(annotations)
|
||||||
|
|
||||||
# Paper formula:
|
# Paper formula:
|
||||||
# ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
|
# ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
|
||||||
@@ -88,22 +96,14 @@ class TestComputePriors:
|
|||||||
assert abs(result['subtask1'] - 0.5) < 1e-6
|
assert abs(result['subtask1'] - 0.5) < 1e-6
|
||||||
assert abs(result['subtask2'] - 0.5) < 1e-6
|
assert abs(result['subtask2'] - 0.5) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
def test_single_trajectory(self):
|
def test_single_trajectory(self):
|
||||||
"""Test with a single trajectory."""
|
"""Test with a single trajectory."""
|
||||||
subtask_durations = {
|
# T=100s, reach=30s, grasp=20s, lift=50s
|
||||||
'reach': [30],
|
annotations = {
|
||||||
'grasp': [20],
|
0: make_annotation([('reach', 0, 30), ('grasp', 30, 50), ('lift', 50, 100)]),
|
||||||
'lift': [50],
|
|
||||||
}
|
}
|
||||||
trajectory_lengths = {
|
|
||||||
'reach': [100],
|
|
||||||
'grasp': [100],
|
|
||||||
'lift': [100],
|
|
||||||
}
|
|
||||||
subtask_names = ['grasp', 'lift', 'reach'] # sorted order
|
|
||||||
|
|
||||||
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
|
result = compute_temporal_proportions(annotations)
|
||||||
|
|
||||||
assert abs(result['reach'] - 0.3) < 1e-6
|
assert abs(result['reach'] - 0.3) < 1e-6
|
||||||
assert abs(result['grasp'] - 0.2) < 1e-6
|
assert abs(result['grasp'] - 0.2) < 1e-6
|
||||||
@@ -111,49 +111,35 @@ class TestComputePriors:
|
|||||||
|
|
||||||
def test_sum_to_one(self):
|
def test_sum_to_one(self):
|
||||||
"""Test that proportions always sum to 1."""
|
"""Test that proportions always sum to 1."""
|
||||||
subtask_durations = {
|
# Three episodes with varying proportions
|
||||||
'a': [10, 20, 30],
|
annotations = {
|
||||||
'b': [40, 50, 60],
|
0: make_annotation([('a', 0, 10), ('b', 10, 50), ('c', 50, 100)]), # 0.1, 0.4, 0.5
|
||||||
'c': [50, 30, 10],
|
1: make_annotation([('a', 0, 20), ('b', 20, 70), ('c', 70, 100)]), # 0.2, 0.5, 0.3
|
||||||
|
2: make_annotation([('a', 0, 30), ('b', 30, 90), ('c', 90, 100)]), # 0.3, 0.6, 0.1
|
||||||
}
|
}
|
||||||
trajectory_lengths = {
|
|
||||||
'a': [100, 100, 100],
|
|
||||||
'b': [100, 100, 100],
|
|
||||||
'c': [100, 100, 100],
|
|
||||||
}
|
|
||||||
subtask_names = ['a', 'b', 'c']
|
|
||||||
|
|
||||||
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
|
result = compute_temporal_proportions(annotations)
|
||||||
|
|
||||||
total = sum(result.values())
|
total = sum(result.values())
|
||||||
assert abs(total - 1.0) < 1e-6
|
assert abs(total - 1.0) < 1e-6
|
||||||
|
|
||||||
def test_empty_subtask_names_raises(self):
|
def test_empty_annotations_returns_empty(self):
|
||||||
"""Test that empty subtask_names raises an error."""
|
"""Test that empty annotations returns empty dict."""
|
||||||
with pytest.raises(ValueError, match="subtask_names cannot be empty"):
|
result = compute_temporal_proportions({})
|
||||||
compute_priors({}, {}, [])
|
assert result == {}
|
||||||
|
|
||||||
def test_missing_subtask_gets_zero_before_normalization(self):
|
def test_uniform_proportions(self):
|
||||||
"""Test handling of subtasks that appear in some but not all trajectories."""
|
"""Test with uniform proportions across subtasks."""
|
||||||
# subtask1 appears in both, subtask2 only in first
|
# Each subtask takes 25% of each episode
|
||||||
subtask_durations = {
|
annotations = {
|
||||||
'subtask1': [50, 100],
|
0: make_annotation([('a', 0, 25), ('b', 25, 50), ('c', 50, 75), ('d', 75, 100)]),
|
||||||
'subtask2': [50], # only in first trajectory
|
1: make_annotation([('a', 0, 50), ('b', 50, 100), ('c', 100, 150), ('d', 150, 200)]),
|
||||||
}
|
}
|
||||||
trajectory_lengths = {
|
|
||||||
'subtask1': [100, 200],
|
|
||||||
'subtask2': [100],
|
|
||||||
}
|
|
||||||
subtask_names = ['subtask1', 'subtask2']
|
|
||||||
|
|
||||||
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
|
result = compute_temporal_proportions(annotations)
|
||||||
|
|
||||||
# subtask1: (50/100 + 100/200) / 2 = (0.5 + 0.5) / 2 = 0.5
|
for name in ['a', 'b', 'c', 'd']:
|
||||||
# subtask2: 50/100 = 0.5 (only one occurrence)
|
assert abs(result[name] - 0.25) < 1e-6
|
||||||
# After normalization: both should be 0.5
|
|
||||||
assert result['subtask1'] > 0
|
|
||||||
assert result['subtask2'] > 0
|
|
||||||
assert abs(sum(result.values()) - 1.0) < 1e-6
|
|
||||||
|
|
||||||
|
|
||||||
class TestComputeTau:
|
class TestComputeTau:
|
||||||
|
|||||||
Reference in New Issue
Block a user