mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-25 20:27:05 +00:00
cleanup and refactor
This commit is contained in:
@@ -25,14 +25,7 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
@PreTrainedConfig.register_subclass("sarm")
|
||||
@dataclass
|
||||
class SARMConfig(PreTrainedConfig):
|
||||
"""Configuration class for SARM (Stage-Aware Reward Modeling).
|
||||
|
||||
SARM is a dual-head reward model that jointly predicts:
|
||||
1. High-level task stage (classification)
|
||||
2. Fine-grained progress within each stage (regression)
|
||||
|
||||
It uses CLIP for visual encoding and supports joint state input.
|
||||
"""
|
||||
"""Configuration class for SARM (Stage-Aware Reward Modeling)"""
|
||||
|
||||
# Visual encoding parameters
|
||||
image_dim: int = 512 # CLIP embedding dimension
|
||||
@@ -40,21 +33,20 @@ class SARMConfig(PreTrainedConfig):
|
||||
frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second)
|
||||
|
||||
# Text encoding parameters
|
||||
text_dim: int = 384 # MiniLM embedding dimension
|
||||
text_dim: int = 384
|
||||
|
||||
# Joint state parameters
|
||||
state_dim: int | None = None # Auto-detected from dataset if None
|
||||
use_joint_state: bool = True # Whether to use joint state input
|
||||
|
||||
# Architecture parameters
|
||||
hidden_dim: int = 768 # Transformer hidden dimension
|
||||
num_heads: int = 12 # Number of attention heads
|
||||
num_layers: int = 8 # Number of transformer layers
|
||||
hidden_dim: int = 768
|
||||
num_heads: int = 12
|
||||
num_layers: int = 8
|
||||
num_stages: int = 5 # Number of task stages for classification (auto-updated from annotations if available)
|
||||
subtask_names: list | None = None # List of subtask names (auto-populated from annotations)
|
||||
|
||||
# Temporal parameters
|
||||
max_length: int = 9 # Maximum video sequence length (should match num_frames)
|
||||
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
|
||||
use_temporal_sampler: bool = True # Always enable temporal sequence loading
|
||||
sampling_mode: str = "sarm" # Sampling mode: "sarm" or "rewind"
|
||||
|
||||
@@ -65,24 +57,13 @@ class SARMConfig(PreTrainedConfig):
|
||||
dropout: float = 0.1 # Dropout rate
|
||||
stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations
|
||||
|
||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
||||
enable_rabc: bool = False # Enable RA-BC weighted loss
|
||||
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
|
||||
rabc_epsilon: float = 1e-6 # Small constant to avoid division by zero
|
||||
chunk_length: int = 25 # Action chunk length for computing progress deltas
|
||||
|
||||
# Model loading
|
||||
pretrained_model_path: str | None = None
|
||||
|
||||
# Device settings
|
||||
device: str | None = None
|
||||
|
||||
# Processor settings
|
||||
image_key: str = "observation.images.top" # Key for images in dataset
|
||||
image_key: str = "observation.images.top" # Key for image used from the dataset
|
||||
task_description: str = "perform the task" # Default task description
|
||||
encode_on_the_fly: bool = True # Encode images/text during training
|
||||
use_dataset_task: bool = True # Use task descriptions from dataset
|
||||
use_subtask_annotations: bool = True # Use subtask annotations for stage-aware training if available
|
||||
|
||||
# Video_features and text_features are generated by the processor from raw images/text, we don't declare them as VISUAL/LANGUAGE here to avoid validation errors
|
||||
input_features: dict = field(default_factory=lambda: {
|
||||
@@ -122,7 +103,7 @@ class SARMConfig(PreTrainedConfig):
|
||||
|
||||
if self.sampling_mode not in ["sarm", "rewind", "custom"]:
|
||||
raise ValueError(
|
||||
f"sampling_mode must be 'sarm', 'rewind', or 'custom', got {self.sampling_mode}"
|
||||
f"sampling_mode must be 'sarm' or 'rewind', got {self.sampling_mode}"
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
|
||||
@@ -63,33 +63,30 @@ class SARMTransformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
video_dim: int = 512, # CLIP dimension
|
||||
text_dim: int = 384, # MiniLM dimension
|
||||
state_dim: int = 14, # Joint state dimension
|
||||
video_dim: int = 512,
|
||||
text_dim: int = 384,
|
||||
state_dim: int = 14,
|
||||
hidden_dim: int = 768,
|
||||
num_heads: int = 12,
|
||||
num_layers: int = 8,
|
||||
num_stages: int = 5,
|
||||
max_length: int = 9,
|
||||
dropout: float = 0.1,
|
||||
use_joint_state: bool = True
|
||||
dropout: float = 0.1
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.max_length = max_length
|
||||
self.num_stages = num_stages
|
||||
self.use_joint_state = use_joint_state
|
||||
|
||||
# Project video, text, and state to common dimension
|
||||
# Project video, text, and state to same dimension
|
||||
self.video_proj = nn.Linear(video_dim, hidden_dim)
|
||||
self.text_proj = nn.Linear(text_dim, hidden_dim)
|
||||
if use_joint_state:
|
||||
self.state_proj = nn.Linear(state_dim, hidden_dim)
|
||||
self.state_proj = nn.Linear(state_dim, hidden_dim)
|
||||
|
||||
# Position embedding only for the first frame
|
||||
self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim))
|
||||
|
||||
# Transformer encoder (shared backbone)
|
||||
# Transformer encoder
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=hidden_dim,
|
||||
nhead=num_heads,
|
||||
@@ -157,12 +154,10 @@ class SARMTransformer(nn.Module):
|
||||
# Project inputs to common dimension
|
||||
video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim]
|
||||
text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim]
|
||||
|
||||
# Add joint state if provided
|
||||
if self.use_joint_state and state_features is not None:
|
||||
state_embed = self.state_proj(state_features) # [batch_size, seq_len, hidden_dim]
|
||||
# Fuse video and state features (simple addition)
|
||||
video_embed = video_embed + state_embed
|
||||
|
||||
state_embed = self.state_proj(state_features) # [batch_size, seq_len, hidden_dim]
|
||||
# Fuse video and state features (simple addition)
|
||||
video_embed = video_embed + state_embed
|
||||
|
||||
# Add positional embedding to first video frame
|
||||
video_embed[:, 0] += self.first_pos_embed
|
||||
@@ -274,16 +269,10 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
num_layers=config.num_layers,
|
||||
num_stages=config.num_stages,
|
||||
max_length=config.max_length,
|
||||
dropout=config.dropout,
|
||||
use_joint_state=config.use_joint_state
|
||||
dropout=config.dropout
|
||||
)
|
||||
self.sarm_transformer.to(self.device)
|
||||
|
||||
# RA-BC running statistics (for weighted loss)
|
||||
if config.enable_rabc:
|
||||
self.register_buffer("rabc_mean", torch.tensor(0.0))
|
||||
self.register_buffer("rabc_m2", torch.tensor(0.0))
|
||||
self.register_buffer("rabc_count", torch.tensor(0))
|
||||
|
||||
logging.info(f"SARM Reward Model initialized on {self.device}")
|
||||
|
||||
@@ -474,40 +463,6 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
|
||||
return rewards
|
||||
|
||||
def _update_rabc_stats(self, progress_deltas: torch.Tensor):
|
||||
"""Update running statistics for RA-BC using Welford's online algorithm."""
|
||||
if not self.config.enable_rabc:
|
||||
return
|
||||
|
||||
for delta in progress_deltas:
|
||||
self.rabc_count += 1
|
||||
delta_val = delta.item()
|
||||
delta_mean = delta_val - self.rabc_mean
|
||||
self.rabc_mean += delta_mean / self.rabc_count
|
||||
delta_m2 = delta_val - self.rabc_mean
|
||||
self.rabc_m2 += delta_mean * delta_m2
|
||||
|
||||
def _compute_rabc_weights(self, progress_deltas: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute RA-BC weights for progress deltas."""
|
||||
if not self.config.enable_rabc or self.rabc_count < 2:
|
||||
return torch.ones_like(progress_deltas)
|
||||
|
||||
# Get running statistics
|
||||
mean = max(self.rabc_mean.item(), 0.0) # Clamp mean to non-negative
|
||||
variance = self.rabc_m2 / (self.rabc_count - 1)
|
||||
std = torch.sqrt(variance).item()
|
||||
|
||||
# Compute soft weights
|
||||
lower_bound = mean - 2 * std
|
||||
upper_bound = mean + 2 * std
|
||||
weights = (progress_deltas - lower_bound) / (4 * std + self.config.rabc_epsilon)
|
||||
weights = torch.clamp(weights, 0.0, 1.0)
|
||||
|
||||
# Apply hard threshold
|
||||
high_quality_mask = progress_deltas > self.config.rabc_kappa
|
||||
weights = torch.where(high_quality_mask, torch.ones_like(weights), weights)
|
||||
|
||||
return weights
|
||||
|
||||
def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False):
|
||||
"""Load pretrained model weights from a checkpoint file."""
|
||||
@@ -565,274 +520,169 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
"""Required by PreTrainedPolicy but not used for SARM."""
|
||||
raise NotImplementedError("SARM model does not select actions")
|
||||
|
||||
def _get_remaining_length(self, observation: dict, idx: int) -> float | None:
|
||||
"""Extract remaining length for a sample from observation metadata."""
|
||||
remaining_lengths = observation.get('remaining_length')
|
||||
if remaining_lengths is None:
|
||||
return None
|
||||
if isinstance(remaining_lengths, torch.Tensor):
|
||||
return remaining_lengths[idx].item() if remaining_lengths.dim() > 0 else remaining_lengths.item()
|
||||
return remaining_lengths
|
||||
|
||||
def _compute_progress_targets(self, remaining_length: float | None, seq_len: int) -> torch.Tensor:
|
||||
"""Compute progress targets based on remaining trajectory length."""
|
||||
if remaining_length is not None and remaining_length > 0:
|
||||
return torch.arange(1, seq_len + 1, dtype=torch.float32, device=self.device) / remaining_length
|
||||
else:
|
||||
raise ValueError("Remaining length is None, but is required for progress targets")
|
||||
|
||||
def _apply_rewind_augmentation(
|
||||
self,
|
||||
video: torch.Tensor,
|
||||
progress: torch.Tensor,
|
||||
state: torch.Tensor | None,
|
||||
max_length: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
"""Apply rewind augmentation: append 2-4 reversed frames (SARM paper)."""
|
||||
num_reverse = random.randint(2, min(4, max_length - 1))
|
||||
|
||||
# Reverse and take frames (skip first which is last of original)
|
||||
reversed_video = video.flip(0)[1:num_reverse + 1]
|
||||
reversed_progress = progress.flip(0)[1:num_reverse + 1]
|
||||
|
||||
# Concatenate and trim
|
||||
video = torch.cat([video, reversed_video], dim=0)[:max_length]
|
||||
progress = torch.cat([progress, reversed_progress], dim=0)[:max_length]
|
||||
|
||||
if state is not None:
|
||||
reversed_state = state.flip(0)[1:num_reverse + 1]
|
||||
state = torch.cat([state, reversed_state], dim=0)[:max_length]
|
||||
|
||||
return video, progress, state
|
||||
|
||||
def _ensure_sequence_length(self, tensor: torch.Tensor, target_len: int) -> torch.Tensor:
|
||||
"""Pad or trim tensor to target length."""
|
||||
current_len = tensor.shape[0]
|
||||
if current_len == target_len:
|
||||
return tensor
|
||||
if current_len < target_len:
|
||||
padding = target_len - current_len
|
||||
return torch.cat([tensor, tensor[-1:].expand(padding, *tensor.shape[1:])])
|
||||
return tensor[:target_len]
|
||||
|
||||
def forward(self, batch):
|
||||
"""
|
||||
Forward pass compatible with lerobot training pipeline.
|
||||
Forward pass for SARM reward model training.
|
||||
|
||||
Args:
|
||||
batch: Dictionary containing observation with:
|
||||
- 'video_features': Pre-encoded video features (B, T, 512)
|
||||
- 'text_features': Pre-encoded text features (B, 384)
|
||||
- 'state_features': Joint state features (B, T, state_dim)
|
||||
batch: Dictionary with 'observation' containing:
|
||||
- 'video_features': (B, T, 512) pre-encoded video features
|
||||
- 'text_features': (B, 384) pre-encoded text features
|
||||
- 'state_features': (B, T, state_dim) joint state features
|
||||
- 'remaining_length': (B,) remaining trajectory lengths (optional)
|
||||
- 'stage_labels': (B, T) stage labels (optional, from annotations)
|
||||
- 'progress_targets': (B, T, 1) progress targets (optional, from annotations)
|
||||
|
||||
Returns:
|
||||
loss: Total training loss
|
||||
output_dict: Dictionary of loss components for logging
|
||||
Tuple of (total_loss, output_dict with loss components)
|
||||
"""
|
||||
# Extract from observation dict
|
||||
observation = batch.get('observation', batch)
|
||||
|
||||
# Extract required features
|
||||
video_features = observation['video_features'].to(self.device)
|
||||
text_features = observation['text_features'].to(self.device)
|
||||
state_features = observation.get('state_features', None)
|
||||
state_features = observation.get('state_features')
|
||||
if state_features is not None:
|
||||
state_features = state_features.to(self.device)
|
||||
|
||||
# Extract stage labels and progress targets if available (from subtask annotations)
|
||||
stage_labels = observation.get('stage_labels', None)
|
||||
if stage_labels is not None:
|
||||
stage_labels = stage_labels.to(self.device)
|
||||
|
||||
progress_targets_from_annotations = observation.get('progress_targets', None)
|
||||
if progress_targets_from_annotations is not None:
|
||||
progress_targets_from_annotations = progress_targets_from_annotations.to(self.device)
|
||||
|
||||
batch_size = video_features.shape[0]
|
||||
max_length = self.config.num_frames
|
||||
|
||||
# Handle both single frames and sequences
|
||||
# Ensure 3D video features (B, T, D)
|
||||
if video_features.dim() == 2:
|
||||
# Single frames: replicate to create pseudo-sequences
|
||||
video_features = video_features.unsqueeze(1).repeat(1, max_length, 1)
|
||||
|
||||
video_features = video_features.unsqueeze(1).expand(-1, max_length, -1)
|
||||
if state_features is not None and state_features.dim() == 2:
|
||||
# Single state: replicate to match sequence length
|
||||
state_features = state_features.unsqueeze(1).repeat(1, max_length, 1)
|
||||
|
||||
# Apply rewind augmentation (following SARM paper: up to 4 reversed frames)
|
||||
# Note: video_features are already sampled by dataset (9 frames with 30-frame gaps)
|
||||
# We just need to compute progress targets and optionally apply rewind
|
||||
state_features = state_features.unsqueeze(1).expand(-1, max_length, -1)
|
||||
|
||||
# Process each sample: compute progress targets and apply rewind augmentation
|
||||
processed_videos = []
|
||||
processed_states = []
|
||||
progress_targets = []
|
||||
|
||||
# Extract episode metadata for correct progress normalization
|
||||
absolute_frame_indices = observation.get('absolute_frame_indices', None)
|
||||
episode_lengths = observation.get('episode_length', None)
|
||||
remaining_lengths = observation.get('remaining_length', None)
|
||||
|
||||
for i in range(batch_size):
|
||||
# Get metadata for this sample
|
||||
current_absolute_indices = None
|
||||
current_episode_length = None
|
||||
current_remaining_length = None
|
||||
remaining_length = self._get_remaining_length(observation, i)
|
||||
progress = self._compute_progress_targets(remaining_length, max_length)
|
||||
|
||||
if absolute_frame_indices is not None:
|
||||
if isinstance(absolute_frame_indices, list):
|
||||
current_absolute_indices = absolute_frame_indices[i]
|
||||
else:
|
||||
current_absolute_indices = absolute_frame_indices
|
||||
video = video_features[i]
|
||||
state = state_features[i] if state_features is not None else None
|
||||
|
||||
if episode_lengths is not None:
|
||||
if isinstance(episode_lengths, torch.Tensor) and episode_lengths.dim() > 0:
|
||||
current_episode_length = episode_lengths[i].item()
|
||||
else:
|
||||
current_episode_length = episode_lengths.item() if isinstance(episode_lengths, torch.Tensor) else episode_lengths
|
||||
|
||||
if remaining_lengths is not None:
|
||||
if isinstance(remaining_lengths, torch.Tensor) and remaining_lengths.dim() > 0:
|
||||
current_remaining_length = remaining_lengths[i].item()
|
||||
else:
|
||||
current_remaining_length = remaining_lengths.item() if isinstance(remaining_lengths, torch.Tensor) else remaining_lengths
|
||||
|
||||
# Compute progress targets directly from metadata (frames already loaded by dataset)
|
||||
# Progress = (position_in_sequence + 1) / remaining_trajectory_length
|
||||
if current_remaining_length is not None and current_remaining_length > 0:
|
||||
# Correct: relative progress from first loaded frame to episode end
|
||||
progress_indices = torch.arange(1, max_length + 1, dtype=torch.float32, device=self.device)
|
||||
progress = progress_indices / current_remaining_length
|
||||
else:
|
||||
# Fallback: linear progress (when metadata is not available)
|
||||
logging.warning(f"Sample {i}: No remaining_length metadata, using linear progress fallback")
|
||||
progress = torch.linspace(1.0/max_length, 1.0, max_length, device=self.device)
|
||||
|
||||
# Apply rewind augmentation with 50% probability (following SARM paper)
|
||||
# Paper specifies: "appending up to four frames from earlier timestamps with reversed order"
|
||||
# Apply rewind augmentation with 50% probability (SARM paper)
|
||||
if random.random() < 0.5:
|
||||
# Rewind: append 2-4 reversed frames, trim to max_length
|
||||
num_reverse = random.randint(2, min(4, max_length - 1))
|
||||
|
||||
# Reverse video and progress
|
||||
reversed_video = video_features[i].flip(0)
|
||||
reversed_progress = progress.flip(0)
|
||||
|
||||
# Take frames from reversed (skip first which is last of original)
|
||||
reverse_frames = reversed_video[1:num_reverse+1]
|
||||
reverse_progress = reversed_progress[1:num_reverse+1]
|
||||
|
||||
# Concatenate forward + reversed
|
||||
rewound_video = torch.cat([video_features[i], reverse_frames], dim=0)
|
||||
rewound_progress = torch.cat([progress, reverse_progress], dim=0)
|
||||
|
||||
# Trim to max_length
|
||||
rewound_video = rewound_video[:max_length]
|
||||
rewound_progress = rewound_progress[:max_length]
|
||||
|
||||
processed_videos.append(rewound_video)
|
||||
progress_targets.append(rewound_progress)
|
||||
|
||||
# Process state features if available
|
||||
if state_features is not None:
|
||||
reversed_state = state_features[i].flip(0)
|
||||
reverse_state_frames = reversed_state[1:num_reverse+1]
|
||||
rewound_state = torch.cat([state_features[i], reverse_state_frames], dim=0)
|
||||
rewound_state = rewound_state[:max_length]
|
||||
processed_states.append(rewound_state)
|
||||
else:
|
||||
# Normal: use frames as-is with forward progress
|
||||
processed_videos.append(video_features[i])
|
||||
progress_targets.append(progress)
|
||||
|
||||
# Process state features if available
|
||||
if state_features is not None:
|
||||
processed_states.append(state_features[i])
|
||||
video, progress, state = self._apply_rewind_augmentation(video, progress, state, max_length)
|
||||
|
||||
# Ensure correct sequence length
|
||||
video = self._ensure_sequence_length(video, max_length)
|
||||
progress = self._ensure_sequence_length(progress.unsqueeze(-1), max_length).squeeze(-1)
|
||||
if state is not None:
|
||||
state = self._ensure_sequence_length(state, max_length)
|
||||
|
||||
processed_videos.append(video)
|
||||
progress_targets.append(progress)
|
||||
if state is not None:
|
||||
processed_states.append(state)
|
||||
|
||||
# Ensure all sequences have the same length before stacking
|
||||
# (sampling functions should return max_length, but double-check)
|
||||
validated_videos = []
|
||||
validated_progress = []
|
||||
for i, (vid, prog) in enumerate(zip(processed_videos, progress_targets)):
|
||||
if len(vid) != max_length:
|
||||
logging.warning(f"Sample {i}: video length {len(vid)} != {max_length}, padding/trimming")
|
||||
if len(vid) < max_length:
|
||||
# Pad
|
||||
padding = max_length - len(vid)
|
||||
vid = torch.cat([vid, vid[-1:].repeat(padding, 1)])
|
||||
prog = torch.cat([prog, torch.full((padding,), prog[-1], device=prog.device)])
|
||||
else:
|
||||
# Trim
|
||||
vid = vid[:max_length]
|
||||
prog = prog[:max_length]
|
||||
validated_videos.append(vid)
|
||||
validated_progress.append(prog)
|
||||
# Stack into batches
|
||||
processed_videos = torch.stack(processed_videos)
|
||||
progress_targets = torch.stack(progress_targets).unsqueeze(-1) # (B, T, 1)
|
||||
processed_states = torch.stack(processed_states) if processed_states else None
|
||||
|
||||
# Stack processed features
|
||||
processed_videos = torch.stack(validated_videos)
|
||||
progress_targets = torch.stack(validated_progress)
|
||||
|
||||
# Ensure progress_targets has the same shape as progress_preds
|
||||
# progress_preds is (batch_size, num_frames, 1)
|
||||
# progress_targets is (batch_size, num_frames) -> add last dimension
|
||||
if progress_targets.dim() == 2:
|
||||
progress_targets = progress_targets.unsqueeze(-1) # (batch_size, num_frames, 1)
|
||||
|
||||
if state_features is not None and len(processed_states) > 0:
|
||||
processed_states = torch.stack(processed_states)
|
||||
else:
|
||||
processed_states = None
|
||||
|
||||
# Get predictions
|
||||
# Get model predictions
|
||||
stage_logits, stage_probs, progress_preds = self.sarm_transformer(
|
||||
processed_videos, text_features, processed_states
|
||||
)
|
||||
|
||||
# Use annotation-based progress targets if available, otherwise use computed ones
|
||||
if progress_targets_from_annotations is not None and len(processed_videos) == 1:
|
||||
# Use refined progress from subtask annotations (single sample case)
|
||||
# Ensure shapes match
|
||||
if progress_targets_from_annotations.shape != progress_preds.shape:
|
||||
if progress_targets_from_annotations.dim() == 2:
|
||||
progress_targets_from_annotations = progress_targets_from_annotations.unsqueeze(0)
|
||||
progress_targets = progress_targets_from_annotations
|
||||
# Use annotation-based progress targets
|
||||
progress_from_annotations = observation.get('progress_targets')
|
||||
if progress_from_annotations is not None:
|
||||
progress_from_annotations = progress_from_annotations.to(self.device)
|
||||
if progress_from_annotations.dim() == 2:
|
||||
progress_from_annotations = progress_from_annotations.unsqueeze(-1)
|
||||
if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
|
||||
progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
|
||||
progress_targets = progress_from_annotations
|
||||
|
||||
# Compute progress loss using targets
|
||||
# Compute progress loss
|
||||
progress_loss = F.mse_loss(progress_preds, progress_targets)
|
||||
output_dict = {'progress_loss': progress_loss.item()}
|
||||
total_loss = progress_loss
|
||||
|
||||
# Compute stage loss if labels are available
|
||||
stage_loss = None
|
||||
if stage_labels is not None and len(processed_videos) == 1:
|
||||
# Ensure stage_labels matches the sequence length
|
||||
if stage_labels.dim() == 1 and stage_logits.dim() == 3:
|
||||
# stage_labels: (seq_len,) -> need to expand to (batch, seq_len)
|
||||
stage_labels = stage_labels.unsqueeze(0).expand(stage_logits.shape[0], -1)
|
||||
elif stage_labels.shape[0] != stage_logits.shape[0]:
|
||||
# Single label for batch - expand
|
||||
stage_labels = stage_labels.expand(stage_logits.shape[0], stage_logits.shape[1])
|
||||
|
||||
# Compute cross-entropy loss for stage classification
|
||||
# Compute stage loss if labels available
|
||||
stage_labels = observation.get('stage_labels')
|
||||
if stage_labels is not None:
|
||||
stage_labels = stage_labels.to(self.device)
|
||||
if stage_labels.dim() == 1:
|
||||
stage_labels = stage_labels.unsqueeze(0).expand(batch_size, -1)
|
||||
stage_loss = compute_stage_loss(stage_logits, stage_labels)
|
||||
|
||||
# Combine losses
|
||||
if stage_loss is not None:
|
||||
total_loss = progress_loss + self.config.stage_loss_weight * stage_loss
|
||||
output_dict = {
|
||||
'progress_loss': progress_loss.item(),
|
||||
'stage_loss': stage_loss.item(),
|
||||
}
|
||||
total_loss = total_loss + self.config.stage_loss_weight * stage_loss
|
||||
output_dict['stage_loss'] = stage_loss.item()
|
||||
else:
|
||||
total_loss = progress_loss
|
||||
output_dict = {
|
||||
'progress_loss': progress_loss.item(),
|
||||
}
|
||||
raise ValueError("Stage labels are None, but are required for stage loss")
|
||||
|
||||
# Compute misaligned loss (following SARM paper and ReWiND)
|
||||
# "To improve video-language alignment, task descriptions are occasionally perturbed"
|
||||
if random.random() < 0.2: # 20% probability (matching ReWiND)
|
||||
# Create misaligned pairs by shuffling text features
|
||||
# Misaligned loss: 20% probability (SARM paper - improve video-language alignment)
|
||||
if random.random() < 0.2:
|
||||
shuffle_idx = torch.randperm(batch_size, device=self.device)
|
||||
misaligned_texts = text_features[shuffle_idx]
|
||||
|
||||
# Get predictions for misaligned pairs (should predict zero progress)
|
||||
_, _, misaligned_preds = self.sarm_transformer(
|
||||
processed_videos, misaligned_texts, processed_states
|
||||
processed_videos, text_features[shuffle_idx], processed_states
|
||||
)
|
||||
|
||||
# Target is zero progress for misaligned pairs
|
||||
target_zeros = torch.zeros_like(misaligned_preds)
|
||||
misaligned_loss = F.mse_loss(misaligned_preds, target_zeros)
|
||||
|
||||
# Add to total loss
|
||||
misaligned_loss = F.mse_loss(misaligned_preds, torch.zeros_like(misaligned_preds))
|
||||
total_loss = total_loss + misaligned_loss
|
||||
output_dict['misaligned_loss'] = misaligned_loss.item()
|
||||
|
||||
# RA-BC weighted loss (if enabled)
|
||||
if self.config.enable_rabc:
|
||||
# Compute progress deltas (simplified: use consecutive frame differences)
|
||||
progress_deltas = progress_preds[:, 1:, 0] - progress_preds[:, :-1, 0]
|
||||
progress_deltas = progress_deltas.mean(dim=1) # Average over sequence
|
||||
|
||||
# Update running statistics
|
||||
self._update_rabc_stats(progress_deltas)
|
||||
|
||||
# Compute weights
|
||||
weights = self._compute_rabc_weights(progress_deltas)
|
||||
|
||||
# Apply weighted loss
|
||||
weighted_loss = (total_loss * weights.mean()).sum()
|
||||
total_loss = weighted_loss
|
||||
|
||||
# Add final total loss to output dict
|
||||
output_dict['total_loss'] = total_loss.item()
|
||||
|
||||
return total_loss, output_dict
|
||||
|
||||
|
||||
# Loss utilities
|
||||
def compute_stage_loss(
|
||||
stage_logits: torch.Tensor,
|
||||
target_stages: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute stage classification loss.
|
||||
|
||||
Args:
|
||||
stage_logits: Stage predictions (batch_size, num_frames, num_stages)
|
||||
target_stages: Target stage indices (batch_size, num_frames)
|
||||
|
||||
Returns:
|
||||
Cross-entropy loss
|
||||
"""
|
||||
batch_size, num_frames, num_stages = stage_logits.shape
|
||||
def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor:
|
||||
_, _, num_stages = stage_logits.shape
|
||||
stage_logits_flat = stage_logits.reshape(-1, num_stages)
|
||||
target_stages_flat = target_stages.reshape(-1)
|
||||
|
||||
@@ -840,20 +690,7 @@ def compute_stage_loss(
|
||||
return loss
|
||||
|
||||
|
||||
def compute_progress_loss(
|
||||
progress_preds: torch.Tensor,
|
||||
target_progress: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute progress regression loss.
|
||||
|
||||
Args:
|
||||
progress_preds: Progress predictions (batch_size, num_frames, 1)
|
||||
target_progress: Target progress values (batch_size, num_frames, 1)
|
||||
|
||||
Returns:
|
||||
Mean squared error loss
|
||||
"""
|
||||
def compute_progress_loss(progress_preds: torch.Tensor, target_progress: torch.Tensor) -> torch.Tensor:
|
||||
loss = F.mse_loss(progress_preds, target_progress)
|
||||
return loss
|
||||
|
||||
|
||||
@@ -15,10 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPProcessor
|
||||
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.processor import (
|
||||
@@ -68,16 +70,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
# Compute temporal proportions from subtask annotations if available
|
||||
self.temporal_proportions = None
|
||||
self.subtask_names = None
|
||||
if dataset_meta is not None and config.use_subtask_annotations:
|
||||
if dataset_meta is not None:
|
||||
self._compute_temporal_proportions()
|
||||
|
||||
# Initialize encoders
|
||||
self._init_encoders()
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize CLIP and MiniLM encoders."""
|
||||
from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPProcessor
|
||||
|
||||
device = torch.device(
|
||||
self.config.device if self.config.device
|
||||
else "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -116,13 +115,11 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
logging.info("No subtask annotations found in dataset")
|
||||
return
|
||||
|
||||
# Convert to pandas for easier processing
|
||||
import pandas as pd
|
||||
# Convert to pandas
|
||||
episodes_df = episodes.to_pandas()
|
||||
|
||||
# Collect all subtask names and compute average durations
|
||||
subtask_durations = {}
|
||||
subtask_counts = {}
|
||||
all_subtask_names = set()
|
||||
|
||||
for ep_idx in episodes_df.index:
|
||||
@@ -178,44 +175,166 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}")
|
||||
|
||||
def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features):
|
||||
"""Generate stage labels and refined progress targets from subtask annotations.
|
||||
def _to_numpy_array(self, x) -> np.ndarray:
|
||||
"""Convert input to a 1D numpy array."""
|
||||
if isinstance(x, torch.Tensor):
|
||||
arr = x.cpu().numpy()
|
||||
else:
|
||||
arr = np.array(x)
|
||||
if arr.ndim == 0:
|
||||
arr = np.array([arr.item()])
|
||||
return arr
|
||||
|
||||
def _find_episode_for_frame(self, frame_idx: int) -> int:
|
||||
"""Find the episode index for a given frame index."""
|
||||
for ep_idx in range(len(self.dataset_meta.episodes)):
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
if ep_start <= frame_idx < ep_end:
|
||||
return ep_idx
|
||||
return 0 # Fallback
|
||||
|
||||
def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
|
||||
"""Get episode indices for each frame index."""
|
||||
if episode_index is None:
|
||||
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
|
||||
|
||||
episode_indices = self._to_numpy_array(episode_index)
|
||||
|
||||
# If single episode but multiple frames, compute episode for each frame
|
||||
if len(episode_indices) == 1 and len(frame_indices) > 1:
|
||||
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
|
||||
|
||||
return episode_indices
|
||||
|
||||
def _compute_absolute_indices(self, frame_idx: int, ep_start: int, num_frames: int) -> torch.Tensor:
|
||||
"""Compute absolute frame indices for a sequence."""
|
||||
frame_gap = getattr(self.config, 'frame_gap', 1)
|
||||
|
||||
if frame_gap > 1:
|
||||
indices = [max(ep_start, frame_idx - (num_frames - 1 - i) * frame_gap) for i in range(num_frames)]
|
||||
return torch.tensor(indices)
|
||||
else:
|
||||
start_idx = max(ep_start, frame_idx - num_frames + 1)
|
||||
return torch.arange(start_idx, frame_idx + 1)
|
||||
|
||||
def _compute_episode_metadata(
|
||||
self,
|
||||
frame_indices: np.ndarray,
|
||||
episode_indices: np.ndarray,
|
||||
num_frames: int,
|
||||
is_batch: bool,
|
||||
) -> tuple[list | torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Compute episode metadata for all samples.
|
||||
|
||||
Args:
|
||||
frame_index: Current frame index or indices
|
||||
episode_index: Episode index
|
||||
video_features: Video features tensor to determine sequence length
|
||||
|
||||
Returns:
|
||||
Tuple of (stage_labels, progress_targets) or (None, None) if no annotations
|
||||
Tuple of (absolute_frame_indices, remaining_lengths, episode_lengths)
|
||||
"""
|
||||
if self.temporal_proportions is None or episode_index is None:
|
||||
return None, None
|
||||
absolute_indices_list = []
|
||||
remaining_lengths = []
|
||||
episode_lengths = []
|
||||
|
||||
# Convert to pandas to access annotations
|
||||
import pandas as pd
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
|
||||
# Handle batch processing
|
||||
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
|
||||
for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()):
|
||||
ep_idx, frame_idx = int(ep_idx), int(frame_idx)
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
|
||||
episode_lengths.append(ep_end - ep_start)
|
||||
abs_indices = self._compute_absolute_indices(frame_idx, ep_start, num_frames)
|
||||
absolute_indices_list.append(abs_indices)
|
||||
remaining_lengths.append(ep_end - abs_indices[0].item())
|
||||
|
||||
if is_batch:
|
||||
# Process multiple samples - for now, return None
|
||||
# (batch processing of annotations is complex and not critical)
|
||||
return None, None
|
||||
|
||||
# Single sample processing
|
||||
if isinstance(episode_index, torch.Tensor):
|
||||
ep_idx = int(episode_index.item())
|
||||
return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths)
|
||||
else:
|
||||
ep_idx = int(episode_index)
|
||||
return absolute_indices_list[0], remaining_lengths[0], episode_lengths[0]
|
||||
|
||||
def _compute_stage_and_progress_for_frame(
|
||||
self,
|
||||
current_frame: int,
|
||||
subtask_names: list,
|
||||
subtask_start_frames: list,
|
||||
subtask_end_frames: list,
|
||||
) -> tuple[int, float]:
|
||||
"""Compute stage index and cumulative progress for a single frame.
|
||||
|
||||
if isinstance(frame_index, torch.Tensor):
|
||||
frame_idx = int(frame_index.item())
|
||||
Args:
|
||||
current_frame: Frame index relative to episode start
|
||||
subtask_names: List of subtask names for this episode
|
||||
subtask_start_frames: List of subtask start frames
|
||||
subtask_end_frames: List of subtask end frames
|
||||
|
||||
Returns:
|
||||
Tuple of (stage_idx, cumulative_progress)
|
||||
"""
|
||||
stage_idx = -1
|
||||
cumulative_progress = 0.0
|
||||
|
||||
# Find which subtask this frame belongs to
|
||||
for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)):
|
||||
if current_frame >= start_frame and current_frame <= end_frame:
|
||||
# Found the subtask
|
||||
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0
|
||||
|
||||
# Calculate within-subtask progress
|
||||
subtask_duration = end_frame - start_frame
|
||||
if subtask_duration > 0:
|
||||
within_subtask_progress = (current_frame - start_frame) / subtask_duration
|
||||
else:
|
||||
within_subtask_progress = 1.0
|
||||
|
||||
# Calculate cumulative progress from completed subtasks
|
||||
for k in range(j):
|
||||
prev_name = subtask_names[k]
|
||||
if prev_name in self.temporal_proportions:
|
||||
cumulative_progress += self.temporal_proportions[prev_name]
|
||||
|
||||
# Add current subtask's partial progress
|
||||
if name in self.temporal_proportions:
|
||||
cumulative_progress += self.temporal_proportions[name] * within_subtask_progress
|
||||
|
||||
return stage_idx, cumulative_progress
|
||||
|
||||
# No matching subtask found - estimate based on position
|
||||
if current_frame < subtask_start_frames[0]:
|
||||
return 0, 0.0
|
||||
elif current_frame > subtask_end_frames[-1]:
|
||||
return len(self.subtask_names) - 1, 1.0
|
||||
else:
|
||||
frame_idx = int(frame_index)
|
||||
# Between subtasks - use previous subtask's end state
|
||||
for j in range(len(subtask_names) - 1):
|
||||
if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]:
|
||||
name = subtask_names[j]
|
||||
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j
|
||||
# Sum up all completed subtasks
|
||||
for k in range(j + 1):
|
||||
prev_name = subtask_names[k]
|
||||
if prev_name in self.temporal_proportions:
|
||||
cumulative_progress += self.temporal_proportions[prev_name]
|
||||
return stage_idx, cumulative_progress
|
||||
|
||||
# Get subtask annotations for this episode
|
||||
return 0, 0.0 # Fallback
|
||||
|
||||
def _compute_labels_for_sample(
|
||||
self,
|
||||
frame_idx: int,
|
||||
ep_idx: int,
|
||||
seq_len: int,
|
||||
episodes_df: pd.DataFrame,
|
||||
) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]:
|
||||
"""Compute stage labels and progress targets for a single sample.
|
||||
|
||||
Args:
|
||||
frame_idx: The frame index for this sample
|
||||
ep_idx: The episode index
|
||||
seq_len: Number of frames in the sequence
|
||||
episodes_df: DataFrame with episode metadata
|
||||
|
||||
Returns:
|
||||
Tuple of (stage_labels, progress_targets) tensors with shapes (T,) and (T, 1),
|
||||
or (None, None) if no valid annotations
|
||||
"""
|
||||
# Check if episode has valid annotations
|
||||
if ep_idx >= len(episodes_df):
|
||||
return None, None
|
||||
|
||||
@@ -228,21 +347,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
# Get episode boundaries
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
|
||||
# Determine sequence length
|
||||
if video_features is not None and video_features.dim() > 0:
|
||||
seq_len = video_features.shape[0] if video_features.dim() == 2 else video_features.shape[1]
|
||||
else:
|
||||
seq_len = 1
|
||||
# Get frame gap for temporal sampling
|
||||
frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
|
||||
|
||||
# Generate labels for each frame in the sequence
|
||||
stage_labels = []
|
||||
progress_targets = []
|
||||
|
||||
# Get frame gap for temporal sampling
|
||||
frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
|
||||
|
||||
for i in range(seq_len):
|
||||
# Calculate actual frame index for this position in sequence
|
||||
if frame_gap > 1:
|
||||
@@ -251,326 +363,147 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
else:
|
||||
current_frame = max(0, frame_idx - seq_len + 1 + i - ep_start)
|
||||
|
||||
# Find which subtask this frame belongs to
|
||||
stage_idx = -1
|
||||
within_subtask_progress = 0.0
|
||||
cumulative_progress = 0.0
|
||||
|
||||
for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)):
|
||||
if current_frame >= start_frame and current_frame <= end_frame:
|
||||
# Found the subtask
|
||||
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0
|
||||
|
||||
# Calculate within-subtask progress
|
||||
subtask_duration = end_frame - start_frame
|
||||
if subtask_duration > 0:
|
||||
within_subtask_progress = (current_frame - start_frame) / subtask_duration
|
||||
else:
|
||||
within_subtask_progress = 1.0
|
||||
|
||||
# Calculate cumulative progress
|
||||
for k in range(j):
|
||||
prev_name = subtask_names[k]
|
||||
if prev_name in self.temporal_proportions:
|
||||
cumulative_progress += self.temporal_proportions[prev_name]
|
||||
|
||||
# Add current subtask's partial progress
|
||||
if name in self.temporal_proportions:
|
||||
cumulative_progress += self.temporal_proportions[name] * within_subtask_progress
|
||||
|
||||
break
|
||||
|
||||
# If no matching subtask found, estimate based on position
|
||||
if stage_idx == -1:
|
||||
# Estimate stage based on frame position
|
||||
if current_frame < subtask_start_frames[0]:
|
||||
stage_idx = 0
|
||||
cumulative_progress = 0.0
|
||||
elif current_frame > subtask_end_frames[-1]:
|
||||
stage_idx = len(self.subtask_names) - 1
|
||||
cumulative_progress = 1.0
|
||||
else:
|
||||
# Between subtasks - use previous subtask's end state
|
||||
for j in range(len(subtask_names) - 1):
|
||||
if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]:
|
||||
name = subtask_names[j]
|
||||
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j
|
||||
# Sum up all previous subtasks
|
||||
for k in range(j + 1):
|
||||
prev_name = subtask_names[k]
|
||||
if prev_name in self.temporal_proportions:
|
||||
cumulative_progress += self.temporal_proportions[prev_name]
|
||||
break
|
||||
stage_idx, cumulative_progress = self._compute_stage_and_progress_for_frame(
|
||||
current_frame, subtask_names, subtask_start_frames, subtask_end_frames
|
||||
)
|
||||
|
||||
stage_labels.append(stage_idx)
|
||||
progress_targets.append(cumulative_progress)
|
||||
|
||||
# Convert to tensors
|
||||
stage_labels = torch.tensor(stage_labels, dtype=torch.long)
|
||||
progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1) # Add channel dim
|
||||
progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1)
|
||||
|
||||
return stage_labels, progress_targets
|
||||
|
||||
def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features):
|
||||
"""Generate stage labels and refined progress targets from subtask annotations.
|
||||
|
||||
Args:
|
||||
frame_index: Current frame index or tensor of indices
|
||||
episode_index: Episode index or tensor of indices
|
||||
video_features: Video features tensor to determine sequence length
|
||||
|
||||
Returns:
|
||||
Tuple of (stage_labels, progress_targets) or (None, None) if no annotations.
|
||||
"""
|
||||
if self.temporal_proportions is None or episode_index is None:
|
||||
return None, None
|
||||
|
||||
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
|
||||
|
||||
# Normalize inputs to numpy arrays
|
||||
frame_indices = self._to_numpy_array(frame_index)
|
||||
episode_indices = self._get_episode_indices(frame_indices, episode_index)
|
||||
|
||||
# Determine sequence length
|
||||
if video_features is not None and video_features.dim() >= 2:
|
||||
seq_len = video_features.shape[1] if is_batch else video_features.shape[0]
|
||||
else:
|
||||
seq_len = 1
|
||||
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
|
||||
# Process all samples
|
||||
all_stage_labels = []
|
||||
all_progress_targets = []
|
||||
|
||||
for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()):
|
||||
result = self._compute_labels_for_sample(int(frame_idx), int(ep_idx), seq_len, episodes_df)
|
||||
|
||||
if result[0] is None:
|
||||
all_stage_labels.append(torch.zeros(seq_len, dtype=torch.long))
|
||||
all_progress_targets.append(torch.zeros(seq_len, 1, dtype=torch.float32))
|
||||
else:
|
||||
all_stage_labels.append(result[0])
|
||||
all_progress_targets.append(result[1])
|
||||
|
||||
if is_batch:
|
||||
return torch.stack(all_stage_labels, dim=0), torch.stack(all_progress_targets, dim=0)
|
||||
return all_stage_labels[0], all_progress_targets[0]
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Encode images, text, and normalize states in the transition."""
|
||||
from lerobot.processor.core import TransitionKey
|
||||
|
||||
self._current_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
|
||||
new_transition = self._current_transition
|
||||
new_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
|
||||
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None or not isinstance(observation, dict):
|
||||
return new_transition
|
||||
if not isinstance(observation, dict):
|
||||
raise ValueError("Observation must be a dictionary")
|
||||
|
||||
# Extract and encode images
|
||||
batch_size = 1
|
||||
if self.image_key in observation:
|
||||
image = observation[self.image_key]
|
||||
|
||||
# Handle different image formats
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
|
||||
# Encode images
|
||||
video_features = self._encode_images_batch(image)
|
||||
observation['video_features'] = video_features
|
||||
|
||||
# Get batch size from encoded features
|
||||
batch_size = video_features.shape[0]
|
||||
# 1. Encode images with CLIP
|
||||
image = observation.get(self.image_key)
|
||||
if image is None:
|
||||
raise ValueError(f"Image not found in observation for key: {self.image_key}")
|
||||
|
||||
# Extract and normalize joint states
|
||||
if self.config.use_joint_state:
|
||||
# Look for "state" or "observation.state" in observation
|
||||
state_key = None
|
||||
state_data = None
|
||||
|
||||
if "state" in observation:
|
||||
state_key = "state"
|
||||
state_data = observation["state"]
|
||||
elif "observation.state" in observation:
|
||||
state_key = "observation.state"
|
||||
state_data = observation["observation.state"]
|
||||
|
||||
if state_data is not None:
|
||||
if isinstance(state_data, torch.Tensor):
|
||||
state_data = state_data.cpu().numpy()
|
||||
|
||||
# Normalize if stats available
|
||||
if self.dataset_stats and state_key in self.dataset_stats:
|
||||
mean = self.dataset_stats[state_key]['mean']
|
||||
std = self.dataset_stats[state_key]['std']
|
||||
state_data = (state_data - mean) / (std + 1e-8)
|
||||
|
||||
observation['state_features'] = torch.tensor(state_data, dtype=torch.float32)
|
||||
else:
|
||||
# Create dummy state features if not found
|
||||
if 'video_features' in observation:
|
||||
num_frames = observation['video_features'].shape[0] if observation['video_features'].dim() == 2 else observation['video_features'].shape[1]
|
||||
observation['state_features'] = torch.zeros(batch_size, num_frames, self.config.state_dim)
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
video_features = self._encode_images_batch(image)
|
||||
observation['video_features'] = video_features
|
||||
|
||||
# Get task descriptions
|
||||
task_descriptions = None
|
||||
if 'task' in new_transition:
|
||||
task_descriptions = new_transition['task']
|
||||
# 2. Extract and normalize joint states
|
||||
state_data = observation.get("state") or observation.get("observation.state")
|
||||
if state_data is None:
|
||||
raise ValueError("State data not found in observation (expected 'state' or 'observation.state')")
|
||||
|
||||
if isinstance(state_data, torch.Tensor):
|
||||
state_data = state_data.cpu().numpy()
|
||||
|
||||
state_key = "state" if "state" in observation else "observation.state"
|
||||
if self.dataset_stats and state_key in self.dataset_stats:
|
||||
mean = self.dataset_stats[state_key]['mean']
|
||||
std = self.dataset_stats[state_key]['std']
|
||||
state_data = (state_data - mean) / (std + 1e-8)
|
||||
|
||||
observation['state_features'] = torch.tensor(state_data, dtype=torch.float32)
|
||||
|
||||
# 3. Encode text with MiniLM
|
||||
batch_size = video_features.shape[0]
|
||||
task_descriptions = new_transition.get('task')
|
||||
if task_descriptions is not None:
|
||||
if isinstance(task_descriptions, str):
|
||||
task_descriptions = [task_descriptions] * batch_size
|
||||
|
||||
# Encode text
|
||||
if task_descriptions is not None:
|
||||
text_features = self._encode_text_batch_list(task_descriptions)
|
||||
observation['text_features'] = self._encode_text_batch_list(task_descriptions)
|
||||
else:
|
||||
text_features = self._encode_text_batch(self.task_description, batch_size)
|
||||
observation['text_features'] = self._encode_text_batch(self.task_description, batch_size)
|
||||
|
||||
observation['text_features'] = text_features
|
||||
# 4. Extract frame/episode indices from complementary data
|
||||
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if not isinstance(comp_data, dict):
|
||||
raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
|
||||
|
||||
# Compute episode metadata for progress normalization
|
||||
# Note: Processor runs BEFORE batching, so we need to extract from raw dataset structure
|
||||
# The dataset provides episode_index and index in the raw item
|
||||
frame_index = comp_data.get('index')
|
||||
episode_index = comp_data.get('episode_index')
|
||||
|
||||
# Extract index and episode_index from COMPLEMENTARY_DATA
|
||||
episode_index = None
|
||||
frame_index = None
|
||||
if frame_index is None:
|
||||
raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA")
|
||||
if episode_index is None:
|
||||
raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")
|
||||
|
||||
# Primary location: COMPLEMENTARY_DATA (confirmed from debug logs)
|
||||
if TransitionKey.COMPLEMENTARY_DATA in new_transition:
|
||||
comp_data = new_transition[TransitionKey.COMPLEMENTARY_DATA]
|
||||
if isinstance(comp_data, dict):
|
||||
frame_index = comp_data.get('index')
|
||||
episode_index = comp_data.get('episode_index')
|
||||
|
||||
# Fallback: check other locations
|
||||
if frame_index is None and TransitionKey.OBSERVATION in new_transition:
|
||||
obs = new_transition[TransitionKey.OBSERVATION]
|
||||
if isinstance(obs, dict):
|
||||
frame_index = obs.get('index')
|
||||
if episode_index is None:
|
||||
episode_index = obs.get('episode_index')
|
||||
|
||||
# If we have frame_index but no episode_index, compute it from episode boundaries
|
||||
if frame_index is not None and episode_index is None and self.dataset_meta is not None:
|
||||
# Convert to int if needed
|
||||
if isinstance(frame_index, torch.Tensor):
|
||||
frame_idx = frame_index.item() if frame_index.numel() == 1 else frame_index[0].item()
|
||||
else:
|
||||
frame_idx = int(frame_index)
|
||||
|
||||
# Search through episodes to find which one this frame belongs to
|
||||
for ep_idx in range(len(self.dataset_meta.episodes)):
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
if ep_start <= frame_idx < ep_end:
|
||||
episode_index = ep_idx
|
||||
break
|
||||
|
||||
if self.dataset_meta is not None and frame_index is not None:
|
||||
# Handle batch processing
|
||||
# 5. Compute episode metadata if dataset_meta is available
|
||||
if self.dataset_meta is not None:
|
||||
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
|
||||
frame_indices = self._to_numpy_array(frame_index)
|
||||
episode_indices = self._get_episode_indices(frame_indices, episode_index)
|
||||
|
||||
if is_batch:
|
||||
# Batch case: process multiple samples at once
|
||||
batch_size = frame_index.shape[0]
|
||||
frame_indices = frame_index.cpu().numpy() if isinstance(frame_index, torch.Tensor) else np.array(frame_index)
|
||||
|
||||
# Ensure at least 1D
|
||||
if frame_indices.ndim == 0:
|
||||
frame_indices = np.array([frame_indices.item()])
|
||||
|
||||
# Compute episode_index for each frame if not provided
|
||||
if episode_index is None:
|
||||
episode_indices = []
|
||||
for frame_idx in frame_indices:
|
||||
frame_idx = int(frame_idx)
|
||||
# Search through episodes
|
||||
found = False
|
||||
for ep_idx in range(len(self.dataset_meta.episodes)):
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
if ep_start <= frame_idx < ep_end:
|
||||
episode_indices.append(ep_idx)
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
episode_indices.append(0) # Fallback
|
||||
episode_indices = np.array(episode_indices)
|
||||
else:
|
||||
episode_indices = episode_index.cpu().numpy() if isinstance(episode_index, torch.Tensor) else np.array(episode_index)
|
||||
# Ensure at least 1D
|
||||
if episode_indices.ndim == 0:
|
||||
episode_indices = np.array([episode_indices.item()])
|
||||
|
||||
# CRITICAL FIX: If we have a single episode_index but multiple frame_indices,
|
||||
# compute the correct episode for each frame (they might be from different episodes)
|
||||
if len(episode_indices) == 1 and len(frame_indices) > 1:
|
||||
episode_indices = []
|
||||
for frame_idx in frame_indices:
|
||||
frame_idx = int(frame_idx)
|
||||
# Search through episodes
|
||||
found = False
|
||||
for ep_idx in range(len(self.dataset_meta.episodes)):
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
if ep_start <= frame_idx < ep_end:
|
||||
episode_indices.append(ep_idx)
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
episode_indices.append(0) # Fallback
|
||||
episode_indices = np.array(episode_indices)
|
||||
|
||||
# Compute metadata for each sample in batch
|
||||
absolute_indices_list = []
|
||||
remaining_lengths = []
|
||||
episode_lengths = []
|
||||
|
||||
# Convert to list for safe iteration
|
||||
episode_indices_list = episode_indices.tolist() if hasattr(episode_indices, 'tolist') else list(episode_indices)
|
||||
frame_indices_list = frame_indices.tolist() if hasattr(frame_indices, 'tolist') else list(frame_indices)
|
||||
|
||||
for i, (ep_idx, frame_idx) in enumerate(zip(episode_indices_list, frame_indices_list)):
|
||||
ep_idx = int(ep_idx)
|
||||
frame_idx = int(frame_idx)
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
episode_length = ep_end - ep_start
|
||||
episode_lengths.append(episode_length)
|
||||
|
||||
# Compute absolute indices for this sample
|
||||
if 'video_features' in observation and observation['video_features'].dim() > 1:
|
||||
num_loaded_frames = observation['video_features'].shape[1] # (batch, seq_len, features)
|
||||
frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
|
||||
|
||||
if frame_gap > 1:
|
||||
absolute_indices = []
|
||||
for j in range(num_loaded_frames):
|
||||
offset = -(num_loaded_frames - 1 - j) * frame_gap
|
||||
idx = max(ep_start, frame_idx + offset)
|
||||
absolute_indices.append(idx)
|
||||
absolute_indices = torch.tensor(absolute_indices)
|
||||
else:
|
||||
start_idx = max(ep_start, frame_idx - num_loaded_frames + 1)
|
||||
absolute_indices = torch.arange(start_idx, frame_idx + 1)
|
||||
|
||||
absolute_indices_list.append(absolute_indices)
|
||||
remaining_lengths.append(ep_end - absolute_indices[0].item())
|
||||
else:
|
||||
absolute_indices_list.append(torch.tensor([frame_idx]))
|
||||
remaining_lengths.append(ep_end - frame_idx)
|
||||
|
||||
observation['absolute_frame_indices'] = absolute_indices_list
|
||||
observation['remaining_length'] = torch.tensor(remaining_lengths)
|
||||
observation['episode_length'] = torch.tensor(episode_lengths)
|
||||
# Determine number of frames from video features
|
||||
if video_features.dim() >= 2:
|
||||
num_frames = video_features.shape[1] if is_batch else video_features.shape[0]
|
||||
else:
|
||||
# Single sample case
|
||||
if isinstance(frame_index, torch.Tensor):
|
||||
frame_idx = frame_index.item()
|
||||
else:
|
||||
frame_idx = int(frame_index)
|
||||
|
||||
# Get episode_index
|
||||
if episode_index is None:
|
||||
# Search through episodes
|
||||
for ep_idx in range(len(self.dataset_meta.episodes)):
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
if ep_start <= frame_idx < ep_end:
|
||||
episode_index = ep_idx
|
||||
break
|
||||
if episode_index is None:
|
||||
episode_index = 0 # Fallback
|
||||
|
||||
ep_idx = int(episode_index) if not isinstance(episode_index, int) else episode_index
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
episode_length = ep_end - ep_start
|
||||
|
||||
# Compute absolute indices
|
||||
if 'video_features' in observation and observation['video_features'].dim() > 0:
|
||||
num_loaded_frames = observation['video_features'].shape[0]
|
||||
frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
|
||||
|
||||
if frame_gap > 1:
|
||||
absolute_indices = []
|
||||
for i in range(num_loaded_frames):
|
||||
offset = -(num_loaded_frames - 1 - i) * frame_gap
|
||||
idx = max(ep_start, frame_idx + offset)
|
||||
absolute_indices.append(idx)
|
||||
absolute_indices = torch.tensor(absolute_indices)
|
||||
else:
|
||||
start_idx = max(ep_start, frame_idx - num_loaded_frames + 1)
|
||||
absolute_indices = torch.arange(start_idx, frame_idx + 1)
|
||||
|
||||
observation['absolute_frame_indices'] = absolute_indices
|
||||
observation['remaining_length'] = ep_end - absolute_indices[0].item()
|
||||
else:
|
||||
observation['absolute_frame_indices'] = torch.tensor([frame_idx])
|
||||
observation['remaining_length'] = ep_end - frame_idx
|
||||
|
||||
observation['episode_length'] = episode_length
|
||||
num_frames = 1
|
||||
|
||||
abs_indices, remaining, ep_lengths = self._compute_episode_metadata(
|
||||
frame_indices, episode_indices, num_frames, is_batch
|
||||
)
|
||||
observation['absolute_frame_indices'] = abs_indices
|
||||
observation['remaining_length'] = remaining
|
||||
observation['episode_length'] = ep_lengths
|
||||
|
||||
# Generate stage labels and refined progress from subtask annotations
|
||||
# 6. Generate stage labels and progress targets from subtask annotations
|
||||
if self.temporal_proportions is not None and self.dataset_meta is not None:
|
||||
stage_labels, progress_targets = self._generate_stage_and_progress_labels(
|
||||
frame_index, episode_index, observation.get('video_features')
|
||||
frame_index, episode_index, video_features
|
||||
)
|
||||
if stage_labels is not None:
|
||||
observation['stage_labels'] = stage_labels
|
||||
@@ -714,11 +647,10 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
type=FeatureType.LANGUAGE,
|
||||
shape=(self.config.text_dim,)
|
||||
)
|
||||
if self.config.use_joint_state:
|
||||
features[PipelineFeatureType.OBSERVATION]['state_features'] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.config.num_frames, self.config.state_dim)
|
||||
)
|
||||
features[PipelineFeatureType.OBSERVATION]['state_features'] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.config.num_frames, self.config.state_dim)
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user