From c774818eda617d97dca11c281ef9b14685b7d3ec Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 25 Nov 2025 17:47:36 +0100 Subject: [PATCH] cleanup and refactor --- .../policies/sarm/configuration_sarm.py | 35 +- src/lerobot/policies/sarm/modeling_sarm.py | 419 ++++-------- src/lerobot/policies/sarm/processor_sarm.py | 610 ++++++++---------- 3 files changed, 407 insertions(+), 657 deletions(-) diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py index 624c5ffd0..e990fa83a 100644 --- a/src/lerobot/policies/sarm/configuration_sarm.py +++ b/src/lerobot/policies/sarm/configuration_sarm.py @@ -25,14 +25,7 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig @PreTrainedConfig.register_subclass("sarm") @dataclass class SARMConfig(PreTrainedConfig): - """Configuration class for SARM (Stage-Aware Reward Modeling). - - SARM is a dual-head reward model that jointly predicts: - 1. High-level task stage (classification) - 2. Fine-grained progress within each stage (regression) - - It uses CLIP for visual encoding and supports joint state input. - """ + """Configuration class for SARM (Stage-Aware Reward Modeling)""" # Visual encoding parameters image_dim: int = 512 # CLIP embedding dimension @@ -40,21 +33,20 @@ class SARMConfig(PreTrainedConfig): frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second) # Text encoding parameters - text_dim: int = 384 # MiniLM embedding dimension + text_dim: int = 384 # Joint state parameters state_dim: int | None = None # Auto-detected from dataset if None - use_joint_state: bool = True # Whether to use joint state input # Architecture parameters - hidden_dim: int = 768 # Transformer hidden dimension - num_heads: int = 12 # Number of attention heads - num_layers: int = 8 # Number of transformer layers + hidden_dim: int = 768 + num_heads: int = 12 + num_layers: int = 8 num_stages: int = 5 # Number of task stages for classification (auto-updated from annotations if available) subtask_names: list | None = None # List of subtask names (auto-populated from annotations) # Temporal parameters - max_length: int = 9 # Maximum video sequence length (should match num_frames) + max_length: int = num_frames # Maximum video sequence length (matches num_frames) use_temporal_sampler: bool = True # Always enable temporal sequence loading sampling_mode: str = "sarm" # Sampling mode: "sarm" or "rewind" @@ -65,24 +57,13 @@ class SARMConfig(PreTrainedConfig): dropout: float = 0.1 # Dropout rate stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations - # RA-BC (Reward-Aligned Behavior Cloning) parameters - enable_rabc: bool = False # Enable RA-BC weighted loss - rabc_kappa: float = 0.01 # Hard threshold for high-quality samples - rabc_epsilon: float = 1e-6 # Small constant to avoid division by zero - chunk_length: int = 25 # Action chunk length for computing progress deltas - - # Model loading pretrained_model_path: str | None = None - # Device settings device: str | None = None # Processor settings - image_key: str = "observation.images.top" # Key for images in dataset + image_key: str = "observation.images.top" # Key for image used from the dataset task_description: str = "perform the task" # Default task description - encode_on_the_fly: bool = True # Encode images/text during training - use_dataset_task: bool = True # Use task descriptions from dataset - use_subtask_annotations: bool = True # Use subtask annotations for stage-aware training if available # Video_features and text_features are generated by the processor from raw images/text, we don't declare them as VISUAL/LANGUAGE here to avoid validation errors input_features: dict = field(default_factory=lambda: { @@ -122,7 +103,7 @@ class SARMConfig(PreTrainedConfig): if self.sampling_mode not in ["sarm", "rewind", "custom"]: raise ValueError( - f"sampling_mode must be 'sarm', 'rewind', or 'custom', got {self.sampling_mode}" + f"sampling_mode must be 'sarm' or 'rewind', got {self.sampling_mode}" ) def get_optimizer_preset(self) -> AdamWConfig: diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py index 05c7d8fb7..9d2319fd3 100644 --- a/src/lerobot/policies/sarm/modeling_sarm.py +++ b/src/lerobot/policies/sarm/modeling_sarm.py @@ -63,33 +63,30 @@ class SARMTransformer(nn.Module): def __init__( self, - video_dim: int = 512, # CLIP dimension - text_dim: int = 384, # MiniLM dimension - state_dim: int = 14, # Joint state dimension + video_dim: int = 512, + text_dim: int = 384, + state_dim: int = 14, hidden_dim: int = 768, num_heads: int = 12, num_layers: int = 8, num_stages: int = 5, max_length: int = 9, - dropout: float = 0.1, - use_joint_state: bool = True + dropout: float = 0.1 ): super().__init__() self.hidden_dim = hidden_dim self.max_length = max_length self.num_stages = num_stages - self.use_joint_state = use_joint_state - # Project video, text, and state to common dimension + # Project video, text, and state to same dimension self.video_proj = nn.Linear(video_dim, hidden_dim) self.text_proj = nn.Linear(text_dim, hidden_dim) - if use_joint_state: - self.state_proj = nn.Linear(state_dim, hidden_dim) + self.state_proj = nn.Linear(state_dim, hidden_dim) # Position embedding only for the first frame self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim)) - # Transformer encoder (shared backbone) + # Transformer encoder encoder_layer = nn.TransformerEncoderLayer( d_model=hidden_dim, nhead=num_heads, @@ -157,12 +154,10 @@ class SARMTransformer(nn.Module): # Project inputs to common dimension video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim] text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim] - - # Add joint state if provided - if self.use_joint_state and state_features is not None: - state_embed = self.state_proj(state_features) # [batch_size, seq_len, hidden_dim] - # Fuse video and state features (simple addition) - video_embed = video_embed + state_embed + + state_embed = self.state_proj(state_features) # [batch_size, seq_len, hidden_dim] + # Fuse video and state features (simple addition) + video_embed = video_embed + state_embed # Add positional embedding to first video frame video_embed[:, 0] += self.first_pos_embed @@ -274,16 +269,10 @@ class SARMRewardModel(PreTrainedPolicy): num_layers=config.num_layers, num_stages=config.num_stages, max_length=config.max_length, - dropout=config.dropout, - use_joint_state=config.use_joint_state + dropout=config.dropout ) self.sarm_transformer.to(self.device) - # RA-BC running statistics (for weighted loss) - if config.enable_rabc: - self.register_buffer("rabc_mean", torch.tensor(0.0)) - self.register_buffer("rabc_m2", torch.tensor(0.0)) - self.register_buffer("rabc_count", torch.tensor(0)) logging.info(f"SARM Reward Model initialized on {self.device}") @@ -474,40 +463,6 @@ class SARMRewardModel(PreTrainedPolicy): return rewards - def _update_rabc_stats(self, progress_deltas: torch.Tensor): - """Update running statistics for RA-BC using Welford's online algorithm.""" - if not self.config.enable_rabc: - return - - for delta in progress_deltas: - self.rabc_count += 1 - delta_val = delta.item() - delta_mean = delta_val - self.rabc_mean - self.rabc_mean += delta_mean / self.rabc_count - delta_m2 = delta_val - self.rabc_mean - self.rabc_m2 += delta_mean * delta_m2 - - def _compute_rabc_weights(self, progress_deltas: torch.Tensor) -> torch.Tensor: - """Compute RA-BC weights for progress deltas.""" - if not self.config.enable_rabc or self.rabc_count < 2: - return torch.ones_like(progress_deltas) - - # Get running statistics - mean = max(self.rabc_mean.item(), 0.0) # Clamp mean to non-negative - variance = self.rabc_m2 / (self.rabc_count - 1) - std = torch.sqrt(variance).item() - - # Compute soft weights - lower_bound = mean - 2 * std - upper_bound = mean + 2 * std - weights = (progress_deltas - lower_bound) / (4 * std + self.config.rabc_epsilon) - weights = torch.clamp(weights, 0.0, 1.0) - - # Apply hard threshold - high_quality_mask = progress_deltas > self.config.rabc_kappa - weights = torch.where(high_quality_mask, torch.ones_like(weights), weights) - - return weights def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False): """Load pretrained model weights from a checkpoint file.""" @@ -565,274 +520,169 @@ class SARMRewardModel(PreTrainedPolicy): """Required by PreTrainedPolicy but not used for SARM.""" raise NotImplementedError("SARM model does not select actions") + def _get_remaining_length(self, observation: dict, idx: int) -> float | None: + """Extract remaining length for a sample from observation metadata.""" + remaining_lengths = observation.get('remaining_length') + if remaining_lengths is None: + return None + if isinstance(remaining_lengths, torch.Tensor): + return remaining_lengths[idx].item() if remaining_lengths.dim() > 0 else remaining_lengths.item() + return remaining_lengths + + def _compute_progress_targets(self, remaining_length: float | None, seq_len: int) -> torch.Tensor: + """Compute progress targets based on remaining trajectory length.""" + if remaining_length is not None and remaining_length > 0: + return torch.arange(1, seq_len + 1, dtype=torch.float32, device=self.device) / remaining_length + else: + raise ValueError("Remaining length is None, but is required for progress targets") + + def _apply_rewind_augmentation( + self, + video: torch.Tensor, + progress: torch.Tensor, + state: torch.Tensor | None, + max_length: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + """Apply rewind augmentation: append 2-4 reversed frames (SARM paper).""" + num_reverse = random.randint(2, min(4, max_length - 1)) + + # Reverse and take frames (skip first which is last of original) + reversed_video = video.flip(0)[1:num_reverse + 1] + reversed_progress = progress.flip(0)[1:num_reverse + 1] + + # Concatenate and trim + video = torch.cat([video, reversed_video], dim=0)[:max_length] + progress = torch.cat([progress, reversed_progress], dim=0)[:max_length] + + if state is not None: + reversed_state = state.flip(0)[1:num_reverse + 1] + state = torch.cat([state, reversed_state], dim=0)[:max_length] + + return video, progress, state + + def _ensure_sequence_length(self, tensor: torch.Tensor, target_len: int) -> torch.Tensor: + """Pad or trim tensor to target length.""" + current_len = tensor.shape[0] + if current_len == target_len: + return tensor + if current_len < target_len: + padding = target_len - current_len + return torch.cat([tensor, tensor[-1:].expand(padding, *tensor.shape[1:])]) + return tensor[:target_len] + def forward(self, batch): """ - Forward pass compatible with lerobot training pipeline. + Forward pass for SARM reward model training. Args: - batch: Dictionary containing observation with: - - 'video_features': Pre-encoded video features (B, T, 512) - - 'text_features': Pre-encoded text features (B, 384) - - 'state_features': Joint state features (B, T, state_dim) + batch: Dictionary with 'observation' containing: + - 'video_features': (B, T, 512) pre-encoded video features + - 'text_features': (B, 384) pre-encoded text features + - 'state_features': (B, T, state_dim) joint state features + - 'remaining_length': (B,) remaining trajectory lengths (optional) + - 'stage_labels': (B, T) stage labels (optional, from annotations) + - 'progress_targets': (B, T, 1) progress targets (optional, from annotations) Returns: - loss: Total training loss - output_dict: Dictionary of loss components for logging + Tuple of (total_loss, output_dict with loss components) """ - # Extract from observation dict observation = batch.get('observation', batch) + + # Extract required features video_features = observation['video_features'].to(self.device) text_features = observation['text_features'].to(self.device) - state_features = observation.get('state_features', None) + state_features = observation.get('state_features') if state_features is not None: state_features = state_features.to(self.device) - # Extract stage labels and progress targets if available (from subtask annotations) - stage_labels = observation.get('stage_labels', None) - if stage_labels is not None: - stage_labels = stage_labels.to(self.device) - - progress_targets_from_annotations = observation.get('progress_targets', None) - if progress_targets_from_annotations is not None: - progress_targets_from_annotations = progress_targets_from_annotations.to(self.device) - batch_size = video_features.shape[0] max_length = self.config.num_frames - # Handle both single frames and sequences + # Ensure 3D video features (B, T, D) if video_features.dim() == 2: - # Single frames: replicate to create pseudo-sequences - video_features = video_features.unsqueeze(1).repeat(1, max_length, 1) - + video_features = video_features.unsqueeze(1).expand(-1, max_length, -1) if state_features is not None and state_features.dim() == 2: - # Single state: replicate to match sequence length - state_features = state_features.unsqueeze(1).repeat(1, max_length, 1) - - # Apply rewind augmentation (following SARM paper: up to 4 reversed frames) - # Note: video_features are already sampled by dataset (9 frames with 30-frame gaps) - # We just need to compute progress targets and optionally apply rewind + state_features = state_features.unsqueeze(1).expand(-1, max_length, -1) + # Process each sample: compute progress targets and apply rewind augmentation processed_videos = [] processed_states = [] progress_targets = [] - # Extract episode metadata for correct progress normalization - absolute_frame_indices = observation.get('absolute_frame_indices', None) - episode_lengths = observation.get('episode_length', None) - remaining_lengths = observation.get('remaining_length', None) - for i in range(batch_size): - # Get metadata for this sample - current_absolute_indices = None - current_episode_length = None - current_remaining_length = None + remaining_length = self._get_remaining_length(observation, i) + progress = self._compute_progress_targets(remaining_length, max_length) - if absolute_frame_indices is not None: - if isinstance(absolute_frame_indices, list): - current_absolute_indices = absolute_frame_indices[i] - else: - current_absolute_indices = absolute_frame_indices + video = video_features[i] + state = state_features[i] if state_features is not None else None - if episode_lengths is not None: - if isinstance(episode_lengths, torch.Tensor) and episode_lengths.dim() > 0: - current_episode_length = episode_lengths[i].item() - else: - current_episode_length = episode_lengths.item() if isinstance(episode_lengths, torch.Tensor) else episode_lengths - - if remaining_lengths is not None: - if isinstance(remaining_lengths, torch.Tensor) and remaining_lengths.dim() > 0: - current_remaining_length = remaining_lengths[i].item() - else: - current_remaining_length = remaining_lengths.item() if isinstance(remaining_lengths, torch.Tensor) else remaining_lengths - - # Compute progress targets directly from metadata (frames already loaded by dataset) - # Progress = (position_in_sequence + 1) / remaining_trajectory_length - if current_remaining_length is not None and current_remaining_length > 0: - # Correct: relative progress from first loaded frame to episode end - progress_indices = torch.arange(1, max_length + 1, dtype=torch.float32, device=self.device) - progress = progress_indices / current_remaining_length - else: - # Fallback: linear progress (when metadata is not available) - logging.warning(f"Sample {i}: No remaining_length metadata, using linear progress fallback") - progress = torch.linspace(1.0/max_length, 1.0, max_length, device=self.device) - - # Apply rewind augmentation with 50% probability (following SARM paper) - # Paper specifies: "appending up to four frames from earlier timestamps with reversed order" + # Apply rewind augmentation with 50% probability (SARM paper) if random.random() < 0.5: - # Rewind: append 2-4 reversed frames, trim to max_length - num_reverse = random.randint(2, min(4, max_length - 1)) - - # Reverse video and progress - reversed_video = video_features[i].flip(0) - reversed_progress = progress.flip(0) - - # Take frames from reversed (skip first which is last of original) - reverse_frames = reversed_video[1:num_reverse+1] - reverse_progress = reversed_progress[1:num_reverse+1] - - # Concatenate forward + reversed - rewound_video = torch.cat([video_features[i], reverse_frames], dim=0) - rewound_progress = torch.cat([progress, reverse_progress], dim=0) - - # Trim to max_length - rewound_video = rewound_video[:max_length] - rewound_progress = rewound_progress[:max_length] - - processed_videos.append(rewound_video) - progress_targets.append(rewound_progress) - - # Process state features if available - if state_features is not None: - reversed_state = state_features[i].flip(0) - reverse_state_frames = reversed_state[1:num_reverse+1] - rewound_state = torch.cat([state_features[i], reverse_state_frames], dim=0) - rewound_state = rewound_state[:max_length] - processed_states.append(rewound_state) - else: - # Normal: use frames as-is with forward progress - processed_videos.append(video_features[i]) - progress_targets.append(progress) - - # Process state features if available - if state_features is not None: - processed_states.append(state_features[i]) + video, progress, state = self._apply_rewind_augmentation(video, progress, state, max_length) + + # Ensure correct sequence length + video = self._ensure_sequence_length(video, max_length) + progress = self._ensure_sequence_length(progress.unsqueeze(-1), max_length).squeeze(-1) + if state is not None: + state = self._ensure_sequence_length(state, max_length) + + processed_videos.append(video) + progress_targets.append(progress) + if state is not None: + processed_states.append(state) - # Ensure all sequences have the same length before stacking - # (sampling functions should return max_length, but double-check) - validated_videos = [] - validated_progress = [] - for i, (vid, prog) in enumerate(zip(processed_videos, progress_targets)): - if len(vid) != max_length: - logging.warning(f"Sample {i}: video length {len(vid)} != {max_length}, padding/trimming") - if len(vid) < max_length: - # Pad - padding = max_length - len(vid) - vid = torch.cat([vid, vid[-1:].repeat(padding, 1)]) - prog = torch.cat([prog, torch.full((padding,), prog[-1], device=prog.device)]) - else: - # Trim - vid = vid[:max_length] - prog = prog[:max_length] - validated_videos.append(vid) - validated_progress.append(prog) + # Stack into batches + processed_videos = torch.stack(processed_videos) + progress_targets = torch.stack(progress_targets).unsqueeze(-1) # (B, T, 1) + processed_states = torch.stack(processed_states) if processed_states else None - # Stack processed features - processed_videos = torch.stack(validated_videos) - progress_targets = torch.stack(validated_progress) - - # Ensure progress_targets has the same shape as progress_preds - # progress_preds is (batch_size, num_frames, 1) - # progress_targets is (batch_size, num_frames) -> add last dimension - if progress_targets.dim() == 2: - progress_targets = progress_targets.unsqueeze(-1) # (batch_size, num_frames, 1) - - if state_features is not None and len(processed_states) > 0: - processed_states = torch.stack(processed_states) - else: - processed_states = None - - # Get predictions + # Get model predictions stage_logits, stage_probs, progress_preds = self.sarm_transformer( processed_videos, text_features, processed_states ) - # Use annotation-based progress targets if available, otherwise use computed ones - if progress_targets_from_annotations is not None and len(processed_videos) == 1: - # Use refined progress from subtask annotations (single sample case) - # Ensure shapes match - if progress_targets_from_annotations.shape != progress_preds.shape: - if progress_targets_from_annotations.dim() == 2: - progress_targets_from_annotations = progress_targets_from_annotations.unsqueeze(0) - progress_targets = progress_targets_from_annotations + # Use annotation-based progress targets + progress_from_annotations = observation.get('progress_targets') + if progress_from_annotations is not None: + progress_from_annotations = progress_from_annotations.to(self.device) + if progress_from_annotations.dim() == 2: + progress_from_annotations = progress_from_annotations.unsqueeze(-1) + if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1: + progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1) + progress_targets = progress_from_annotations - # Compute progress loss using targets + # Compute progress loss progress_loss = F.mse_loss(progress_preds, progress_targets) + output_dict = {'progress_loss': progress_loss.item()} + total_loss = progress_loss - # Compute stage loss if labels are available - stage_loss = None - if stage_labels is not None and len(processed_videos) == 1: - # Ensure stage_labels matches the sequence length - if stage_labels.dim() == 1 and stage_logits.dim() == 3: - # stage_labels: (seq_len,) -> need to expand to (batch, seq_len) - stage_labels = stage_labels.unsqueeze(0).expand(stage_logits.shape[0], -1) - elif stage_labels.shape[0] != stage_logits.shape[0]: - # Single label for batch - expand - stage_labels = stage_labels.expand(stage_logits.shape[0], stage_logits.shape[1]) - - # Compute cross-entropy loss for stage classification + # Compute stage loss if labels available + stage_labels = observation.get('stage_labels') + if stage_labels is not None: + stage_labels = stage_labels.to(self.device) + if stage_labels.dim() == 1: + stage_labels = stage_labels.unsqueeze(0).expand(batch_size, -1) stage_loss = compute_stage_loss(stage_logits, stage_labels) - - # Combine losses - if stage_loss is not None: - total_loss = progress_loss + self.config.stage_loss_weight * stage_loss - output_dict = { - 'progress_loss': progress_loss.item(), - 'stage_loss': stage_loss.item(), - } + total_loss = total_loss + self.config.stage_loss_weight * stage_loss + output_dict['stage_loss'] = stage_loss.item() else: - total_loss = progress_loss - output_dict = { - 'progress_loss': progress_loss.item(), - } + raise ValueError("Stage labels are None, but are required for stage loss") - # Compute misaligned loss (following SARM paper and ReWiND) - # "To improve video-language alignment, task descriptions are occasionally perturbed" - if random.random() < 0.2: # 20% probability (matching ReWiND) - # Create misaligned pairs by shuffling text features + # Misaligned loss: 20% probability (SARM paper - improve video-language alignment) + if random.random() < 0.2: shuffle_idx = torch.randperm(batch_size, device=self.device) - misaligned_texts = text_features[shuffle_idx] - - # Get predictions for misaligned pairs (should predict zero progress) _, _, misaligned_preds = self.sarm_transformer( - processed_videos, misaligned_texts, processed_states + processed_videos, text_features[shuffle_idx], processed_states ) - - # Target is zero progress for misaligned pairs - target_zeros = torch.zeros_like(misaligned_preds) - misaligned_loss = F.mse_loss(misaligned_preds, target_zeros) - - # Add to total loss + misaligned_loss = F.mse_loss(misaligned_preds, torch.zeros_like(misaligned_preds)) total_loss = total_loss + misaligned_loss output_dict['misaligned_loss'] = misaligned_loss.item() - # RA-BC weighted loss (if enabled) - if self.config.enable_rabc: - # Compute progress deltas (simplified: use consecutive frame differences) - progress_deltas = progress_preds[:, 1:, 0] - progress_preds[:, :-1, 0] - progress_deltas = progress_deltas.mean(dim=1) # Average over sequence - - # Update running statistics - self._update_rabc_stats(progress_deltas) - - # Compute weights - weights = self._compute_rabc_weights(progress_deltas) - - # Apply weighted loss - weighted_loss = (total_loss * weights.mean()).sum() - total_loss = weighted_loss - - # Add final total loss to output dict output_dict['total_loss'] = total_loss.item() - return total_loss, output_dict - -# Loss utilities -def compute_stage_loss( - stage_logits: torch.Tensor, - target_stages: torch.Tensor -) -> torch.Tensor: - """ - Compute stage classification loss. - - Args: - stage_logits: Stage predictions (batch_size, num_frames, num_stages) - target_stages: Target stage indices (batch_size, num_frames) - - Returns: - Cross-entropy loss - """ - batch_size, num_frames, num_stages = stage_logits.shape +def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor: + _, _, num_stages = stage_logits.shape stage_logits_flat = stage_logits.reshape(-1, num_stages) target_stages_flat = target_stages.reshape(-1) @@ -840,20 +690,7 @@ def compute_stage_loss( return loss -def compute_progress_loss( - progress_preds: torch.Tensor, - target_progress: torch.Tensor -) -> torch.Tensor: - """ - Compute progress regression loss. - - Args: - progress_preds: Progress predictions (batch_size, num_frames, 1) - target_progress: Target progress values (batch_size, num_frames, 1) - - Returns: - Mean squared error loss - """ +def compute_progress_loss(progress_preds: torch.Tensor, target_progress: torch.Tensor) -> torch.Tensor: loss = F.mse_loss(progress_preds, target_progress) return loss diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py index 74ed5fafb..d5bb1f3fc 100644 --- a/src/lerobot/policies/sarm/processor_sarm.py +++ b/src/lerobot/policies/sarm/processor_sarm.py @@ -15,10 +15,12 @@ # limitations under the License. import logging -from typing import Dict, Any, List, Optional +from typing import Any import numpy as np import torch from PIL import Image +import pandas as pd +from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPProcessor from lerobot.policies.sarm.configuration_sarm import SARMConfig from lerobot.processor import ( @@ -68,16 +70,13 @@ class SARMEncodingProcessorStep(ProcessorStep): # Compute temporal proportions from subtask annotations if available self.temporal_proportions = None self.subtask_names = None - if dataset_meta is not None and config.use_subtask_annotations: + if dataset_meta is not None: self._compute_temporal_proportions() - # Initialize encoders self._init_encoders() def _init_encoders(self): """Initialize CLIP and MiniLM encoders.""" - from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPProcessor - device = torch.device( self.config.device if self.config.device else "cuda" if torch.cuda.is_available() else "cpu" @@ -116,13 +115,11 @@ class SARMEncodingProcessorStep(ProcessorStep): logging.info("No subtask annotations found in dataset") return - # Convert to pandas for easier processing - import pandas as pd + # Convert to pandas episodes_df = episodes.to_pandas() # Collect all subtask names and compute average durations subtask_durations = {} - subtask_counts = {} all_subtask_names = set() for ep_idx in episodes_df.index: @@ -178,44 +175,166 @@ class SARMEncodingProcessorStep(ProcessorStep): logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}") - def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features): - """Generate stage labels and refined progress targets from subtask annotations. + def _to_numpy_array(self, x) -> np.ndarray: + """Convert input to a 1D numpy array.""" + if isinstance(x, torch.Tensor): + arr = x.cpu().numpy() + else: + arr = np.array(x) + if arr.ndim == 0: + arr = np.array([arr.item()]) + return arr + + def _find_episode_for_frame(self, frame_idx: int) -> int: + """Find the episode index for a given frame index.""" + for ep_idx in range(len(self.dataset_meta.episodes)): + ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] + if ep_start <= frame_idx < ep_end: + return ep_idx + return 0 # Fallback + + def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray: + """Get episode indices for each frame index.""" + if episode_index is None: + return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices]) + + episode_indices = self._to_numpy_array(episode_index) + + # If single episode but multiple frames, compute episode for each frame + if len(episode_indices) == 1 and len(frame_indices) > 1: + return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices]) + + return episode_indices + + def _compute_absolute_indices(self, frame_idx: int, ep_start: int, num_frames: int) -> torch.Tensor: + """Compute absolute frame indices for a sequence.""" + frame_gap = getattr(self.config, 'frame_gap', 1) + + if frame_gap > 1: + indices = [max(ep_start, frame_idx - (num_frames - 1 - i) * frame_gap) for i in range(num_frames)] + return torch.tensor(indices) + else: + start_idx = max(ep_start, frame_idx - num_frames + 1) + return torch.arange(start_idx, frame_idx + 1) + + def _compute_episode_metadata( + self, + frame_indices: np.ndarray, + episode_indices: np.ndarray, + num_frames: int, + is_batch: bool, + ) -> tuple[list | torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute episode metadata for all samples. - Args: - frame_index: Current frame index or indices - episode_index: Episode index - video_features: Video features tensor to determine sequence length - Returns: - Tuple of (stage_labels, progress_targets) or (None, None) if no annotations + Tuple of (absolute_frame_indices, remaining_lengths, episode_lengths) """ - if self.temporal_proportions is None or episode_index is None: - return None, None + absolute_indices_list = [] + remaining_lengths = [] + episode_lengths = [] - # Convert to pandas to access annotations - import pandas as pd - episodes_df = self.dataset_meta.episodes.to_pandas() - - # Handle batch processing - is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1 + for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()): + ep_idx, frame_idx = int(ep_idx), int(frame_idx) + ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] + + episode_lengths.append(ep_end - ep_start) + abs_indices = self._compute_absolute_indices(frame_idx, ep_start, num_frames) + absolute_indices_list.append(abs_indices) + remaining_lengths.append(ep_end - abs_indices[0].item()) if is_batch: - # Process multiple samples - for now, return None - # (batch processing of annotations is complex and not critical) - return None, None - - # Single sample processing - if isinstance(episode_index, torch.Tensor): - ep_idx = int(episode_index.item()) + return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths) else: - ep_idx = int(episode_index) + return absolute_indices_list[0], remaining_lengths[0], episode_lengths[0] + + def _compute_stage_and_progress_for_frame( + self, + current_frame: int, + subtask_names: list, + subtask_start_frames: list, + subtask_end_frames: list, + ) -> tuple[int, float]: + """Compute stage index and cumulative progress for a single frame. - if isinstance(frame_index, torch.Tensor): - frame_idx = int(frame_index.item()) + Args: + current_frame: Frame index relative to episode start + subtask_names: List of subtask names for this episode + subtask_start_frames: List of subtask start frames + subtask_end_frames: List of subtask end frames + + Returns: + Tuple of (stage_idx, cumulative_progress) + """ + stage_idx = -1 + cumulative_progress = 0.0 + + # Find which subtask this frame belongs to + for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)): + if current_frame >= start_frame and current_frame <= end_frame: + # Found the subtask + stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0 + + # Calculate within-subtask progress + subtask_duration = end_frame - start_frame + if subtask_duration > 0: + within_subtask_progress = (current_frame - start_frame) / subtask_duration + else: + within_subtask_progress = 1.0 + + # Calculate cumulative progress from completed subtasks + for k in range(j): + prev_name = subtask_names[k] + if prev_name in self.temporal_proportions: + cumulative_progress += self.temporal_proportions[prev_name] + + # Add current subtask's partial progress + if name in self.temporal_proportions: + cumulative_progress += self.temporal_proportions[name] * within_subtask_progress + + return stage_idx, cumulative_progress + + # No matching subtask found - estimate based on position + if current_frame < subtask_start_frames[0]: + return 0, 0.0 + elif current_frame > subtask_end_frames[-1]: + return len(self.subtask_names) - 1, 1.0 else: - frame_idx = int(frame_index) + # Between subtasks - use previous subtask's end state + for j in range(len(subtask_names) - 1): + if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]: + name = subtask_names[j] + stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j + # Sum up all completed subtasks + for k in range(j + 1): + prev_name = subtask_names[k] + if prev_name in self.temporal_proportions: + cumulative_progress += self.temporal_proportions[prev_name] + return stage_idx, cumulative_progress - # Get subtask annotations for this episode + return 0, 0.0 # Fallback + + def _compute_labels_for_sample( + self, + frame_idx: int, + ep_idx: int, + seq_len: int, + episodes_df: pd.DataFrame, + ) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]: + """Compute stage labels and progress targets for a single sample. + + Args: + frame_idx: The frame index for this sample + ep_idx: The episode index + seq_len: Number of frames in the sequence + episodes_df: DataFrame with episode metadata + + Returns: + Tuple of (stage_labels, progress_targets) tensors with shapes (T,) and (T, 1), + or (None, None) if no valid annotations + """ + # Check if episode has valid annotations if ep_idx >= len(episodes_df): return None, None @@ -228,21 +347,14 @@ class SARMEncodingProcessorStep(ProcessorStep): # Get episode boundaries ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] - ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] - # Determine sequence length - if video_features is not None and video_features.dim() > 0: - seq_len = video_features.shape[0] if video_features.dim() == 2 else video_features.shape[1] - else: - seq_len = 1 + # Get frame gap for temporal sampling + frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1 # Generate labels for each frame in the sequence stage_labels = [] progress_targets = [] - # Get frame gap for temporal sampling - frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1 - for i in range(seq_len): # Calculate actual frame index for this position in sequence if frame_gap > 1: @@ -251,326 +363,147 @@ class SARMEncodingProcessorStep(ProcessorStep): else: current_frame = max(0, frame_idx - seq_len + 1 + i - ep_start) - # Find which subtask this frame belongs to - stage_idx = -1 - within_subtask_progress = 0.0 - cumulative_progress = 0.0 - - for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)): - if current_frame >= start_frame and current_frame <= end_frame: - # Found the subtask - stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0 - - # Calculate within-subtask progress - subtask_duration = end_frame - start_frame - if subtask_duration > 0: - within_subtask_progress = (current_frame - start_frame) / subtask_duration - else: - within_subtask_progress = 1.0 - - # Calculate cumulative progress - for k in range(j): - prev_name = subtask_names[k] - if prev_name in self.temporal_proportions: - cumulative_progress += self.temporal_proportions[prev_name] - - # Add current subtask's partial progress - if name in self.temporal_proportions: - cumulative_progress += self.temporal_proportions[name] * within_subtask_progress - - break - - # If no matching subtask found, estimate based on position - if stage_idx == -1: - # Estimate stage based on frame position - if current_frame < subtask_start_frames[0]: - stage_idx = 0 - cumulative_progress = 0.0 - elif current_frame > subtask_end_frames[-1]: - stage_idx = len(self.subtask_names) - 1 - cumulative_progress = 1.0 - else: - # Between subtasks - use previous subtask's end state - for j in range(len(subtask_names) - 1): - if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]: - name = subtask_names[j] - stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j - # Sum up all previous subtasks - for k in range(j + 1): - prev_name = subtask_names[k] - if prev_name in self.temporal_proportions: - cumulative_progress += self.temporal_proportions[prev_name] - break + stage_idx, cumulative_progress = self._compute_stage_and_progress_for_frame( + current_frame, subtask_names, subtask_start_frames, subtask_end_frames + ) stage_labels.append(stage_idx) progress_targets.append(cumulative_progress) # Convert to tensors stage_labels = torch.tensor(stage_labels, dtype=torch.long) - progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1) # Add channel dim + progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1) return stage_labels, progress_targets + def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features): + """Generate stage labels and refined progress targets from subtask annotations. + + Args: + frame_index: Current frame index or tensor of indices + episode_index: Episode index or tensor of indices + video_features: Video features tensor to determine sequence length + + Returns: + Tuple of (stage_labels, progress_targets) or (None, None) if no annotations. + """ + if self.temporal_proportions is None or episode_index is None: + return None, None + + is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1 + + # Normalize inputs to numpy arrays + frame_indices = self._to_numpy_array(frame_index) + episode_indices = self._get_episode_indices(frame_indices, episode_index) + + # Determine sequence length + if video_features is not None and video_features.dim() >= 2: + seq_len = video_features.shape[1] if is_batch else video_features.shape[0] + else: + seq_len = 1 + + episodes_df = self.dataset_meta.episodes.to_pandas() + + # Process all samples + all_stage_labels = [] + all_progress_targets = [] + + for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()): + result = self._compute_labels_for_sample(int(frame_idx), int(ep_idx), seq_len, episodes_df) + + if result[0] is None: + all_stage_labels.append(torch.zeros(seq_len, dtype=torch.long)) + all_progress_targets.append(torch.zeros(seq_len, 1, dtype=torch.float32)) + else: + all_stage_labels.append(result[0]) + all_progress_targets.append(result[1]) + + if is_batch: + return torch.stack(all_stage_labels, dim=0), torch.stack(all_progress_targets, dim=0) + return all_stage_labels[0], all_progress_targets[0] + def __call__(self, transition: EnvTransition) -> EnvTransition: """Encode images, text, and normalize states in the transition.""" from lerobot.processor.core import TransitionKey - self._current_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition) - new_transition = self._current_transition + new_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition) observation = new_transition.get(TransitionKey.OBSERVATION) - if observation is None or not isinstance(observation, dict): - return new_transition + if not isinstance(observation, dict): + raise ValueError("Observation must be a dictionary") - # Extract and encode images - batch_size = 1 - if self.image_key in observation: - image = observation[self.image_key] - - # Handle different image formats - if isinstance(image, torch.Tensor): - image = image.cpu().numpy() - - # Encode images - video_features = self._encode_images_batch(image) - observation['video_features'] = video_features - - # Get batch size from encoded features - batch_size = video_features.shape[0] + # 1. Encode images with CLIP + image = observation.get(self.image_key) + if image is None: + raise ValueError(f"Image not found in observation for key: {self.image_key}") - # Extract and normalize joint states - if self.config.use_joint_state: - # Look for "state" or "observation.state" in observation - state_key = None - state_data = None - - if "state" in observation: - state_key = "state" - state_data = observation["state"] - elif "observation.state" in observation: - state_key = "observation.state" - state_data = observation["observation.state"] - - if state_data is not None: - if isinstance(state_data, torch.Tensor): - state_data = state_data.cpu().numpy() - - # Normalize if stats available - if self.dataset_stats and state_key in self.dataset_stats: - mean = self.dataset_stats[state_key]['mean'] - std = self.dataset_stats[state_key]['std'] - state_data = (state_data - mean) / (std + 1e-8) - - observation['state_features'] = torch.tensor(state_data, dtype=torch.float32) - else: - # Create dummy state features if not found - if 'video_features' in observation: - num_frames = observation['video_features'].shape[0] if observation['video_features'].dim() == 2 else observation['video_features'].shape[1] - observation['state_features'] = torch.zeros(batch_size, num_frames, self.config.state_dim) + if isinstance(image, torch.Tensor): + image = image.cpu().numpy() + video_features = self._encode_images_batch(image) + observation['video_features'] = video_features - # Get task descriptions - task_descriptions = None - if 'task' in new_transition: - task_descriptions = new_transition['task'] + # 2. Extract and normalize joint states + state_data = observation.get("state") or observation.get("observation.state") + if state_data is None: + raise ValueError("State data not found in observation (expected 'state' or 'observation.state')") + + if isinstance(state_data, torch.Tensor): + state_data = state_data.cpu().numpy() + + state_key = "state" if "state" in observation else "observation.state" + if self.dataset_stats and state_key in self.dataset_stats: + mean = self.dataset_stats[state_key]['mean'] + std = self.dataset_stats[state_key]['std'] + state_data = (state_data - mean) / (std + 1e-8) + + observation['state_features'] = torch.tensor(state_data, dtype=torch.float32) + + # 3. Encode text with MiniLM + batch_size = video_features.shape[0] + task_descriptions = new_transition.get('task') + if task_descriptions is not None: if isinstance(task_descriptions, str): task_descriptions = [task_descriptions] * batch_size - - # Encode text - if task_descriptions is not None: - text_features = self._encode_text_batch_list(task_descriptions) + observation['text_features'] = self._encode_text_batch_list(task_descriptions) else: - text_features = self._encode_text_batch(self.task_description, batch_size) + observation['text_features'] = self._encode_text_batch(self.task_description, batch_size) - observation['text_features'] = text_features + # 4. Extract frame/episode indices from complementary data + comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + if not isinstance(comp_data, dict): + raise ValueError("COMPLEMENTARY_DATA must be a dictionary") - # Compute episode metadata for progress normalization - # Note: Processor runs BEFORE batching, so we need to extract from raw dataset structure - # The dataset provides episode_index and index in the raw item + frame_index = comp_data.get('index') + episode_index = comp_data.get('episode_index') - # Extract index and episode_index from COMPLEMENTARY_DATA - episode_index = None - frame_index = None + if frame_index is None: + raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA") + if episode_index is None: + raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA") - # Primary location: COMPLEMENTARY_DATA (confirmed from debug logs) - if TransitionKey.COMPLEMENTARY_DATA in new_transition: - comp_data = new_transition[TransitionKey.COMPLEMENTARY_DATA] - if isinstance(comp_data, dict): - frame_index = comp_data.get('index') - episode_index = comp_data.get('episode_index') - - # Fallback: check other locations - if frame_index is None and TransitionKey.OBSERVATION in new_transition: - obs = new_transition[TransitionKey.OBSERVATION] - if isinstance(obs, dict): - frame_index = obs.get('index') - if episode_index is None: - episode_index = obs.get('episode_index') - - # If we have frame_index but no episode_index, compute it from episode boundaries - if frame_index is not None and episode_index is None and self.dataset_meta is not None: - # Convert to int if needed - if isinstance(frame_index, torch.Tensor): - frame_idx = frame_index.item() if frame_index.numel() == 1 else frame_index[0].item() - else: - frame_idx = int(frame_index) - - # Search through episodes to find which one this frame belongs to - for ep_idx in range(len(self.dataset_meta.episodes)): - ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] - ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] - if ep_start <= frame_idx < ep_end: - episode_index = ep_idx - break - - if self.dataset_meta is not None and frame_index is not None: - # Handle batch processing + # 5. Compute episode metadata if dataset_meta is available + if self.dataset_meta is not None: is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1 + frame_indices = self._to_numpy_array(frame_index) + episode_indices = self._get_episode_indices(frame_indices, episode_index) - if is_batch: - # Batch case: process multiple samples at once - batch_size = frame_index.shape[0] - frame_indices = frame_index.cpu().numpy() if isinstance(frame_index, torch.Tensor) else np.array(frame_index) - - # Ensure at least 1D - if frame_indices.ndim == 0: - frame_indices = np.array([frame_indices.item()]) - - # Compute episode_index for each frame if not provided - if episode_index is None: - episode_indices = [] - for frame_idx in frame_indices: - frame_idx = int(frame_idx) - # Search through episodes - found = False - for ep_idx in range(len(self.dataset_meta.episodes)): - ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] - ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] - if ep_start <= frame_idx < ep_end: - episode_indices.append(ep_idx) - found = True - break - if not found: - episode_indices.append(0) # Fallback - episode_indices = np.array(episode_indices) - else: - episode_indices = episode_index.cpu().numpy() if isinstance(episode_index, torch.Tensor) else np.array(episode_index) - # Ensure at least 1D - if episode_indices.ndim == 0: - episode_indices = np.array([episode_indices.item()]) - - # CRITICAL FIX: If we have a single episode_index but multiple frame_indices, - # compute the correct episode for each frame (they might be from different episodes) - if len(episode_indices) == 1 and len(frame_indices) > 1: - episode_indices = [] - for frame_idx in frame_indices: - frame_idx = int(frame_idx) - # Search through episodes - found = False - for ep_idx in range(len(self.dataset_meta.episodes)): - ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] - ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] - if ep_start <= frame_idx < ep_end: - episode_indices.append(ep_idx) - found = True - break - if not found: - episode_indices.append(0) # Fallback - episode_indices = np.array(episode_indices) - - # Compute metadata for each sample in batch - absolute_indices_list = [] - remaining_lengths = [] - episode_lengths = [] - - # Convert to list for safe iteration - episode_indices_list = episode_indices.tolist() if hasattr(episode_indices, 'tolist') else list(episode_indices) - frame_indices_list = frame_indices.tolist() if hasattr(frame_indices, 'tolist') else list(frame_indices) - - for i, (ep_idx, frame_idx) in enumerate(zip(episode_indices_list, frame_indices_list)): - ep_idx = int(ep_idx) - frame_idx = int(frame_idx) - ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] - ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] - episode_length = ep_end - ep_start - episode_lengths.append(episode_length) - - # Compute absolute indices for this sample - if 'video_features' in observation and observation['video_features'].dim() > 1: - num_loaded_frames = observation['video_features'].shape[1] # (batch, seq_len, features) - frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1 - - if frame_gap > 1: - absolute_indices = [] - for j in range(num_loaded_frames): - offset = -(num_loaded_frames - 1 - j) * frame_gap - idx = max(ep_start, frame_idx + offset) - absolute_indices.append(idx) - absolute_indices = torch.tensor(absolute_indices) - else: - start_idx = max(ep_start, frame_idx - num_loaded_frames + 1) - absolute_indices = torch.arange(start_idx, frame_idx + 1) - - absolute_indices_list.append(absolute_indices) - remaining_lengths.append(ep_end - absolute_indices[0].item()) - else: - absolute_indices_list.append(torch.tensor([frame_idx])) - remaining_lengths.append(ep_end - frame_idx) - - observation['absolute_frame_indices'] = absolute_indices_list - observation['remaining_length'] = torch.tensor(remaining_lengths) - observation['episode_length'] = torch.tensor(episode_lengths) + # Determine number of frames from video features + if video_features.dim() >= 2: + num_frames = video_features.shape[1] if is_batch else video_features.shape[0] else: - # Single sample case - if isinstance(frame_index, torch.Tensor): - frame_idx = frame_index.item() - else: - frame_idx = int(frame_index) - - # Get episode_index - if episode_index is None: - # Search through episodes - for ep_idx in range(len(self.dataset_meta.episodes)): - ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] - ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] - if ep_start <= frame_idx < ep_end: - episode_index = ep_idx - break - if episode_index is None: - episode_index = 0 # Fallback - - ep_idx = int(episode_index) if not isinstance(episode_index, int) else episode_index - ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] - ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] - episode_length = ep_end - ep_start - - # Compute absolute indices - if 'video_features' in observation and observation['video_features'].dim() > 0: - num_loaded_frames = observation['video_features'].shape[0] - frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1 - - if frame_gap > 1: - absolute_indices = [] - for i in range(num_loaded_frames): - offset = -(num_loaded_frames - 1 - i) * frame_gap - idx = max(ep_start, frame_idx + offset) - absolute_indices.append(idx) - absolute_indices = torch.tensor(absolute_indices) - else: - start_idx = max(ep_start, frame_idx - num_loaded_frames + 1) - absolute_indices = torch.arange(start_idx, frame_idx + 1) - - observation['absolute_frame_indices'] = absolute_indices - observation['remaining_length'] = ep_end - absolute_indices[0].item() - else: - observation['absolute_frame_indices'] = torch.tensor([frame_idx]) - observation['remaining_length'] = ep_end - frame_idx - - observation['episode_length'] = episode_length + num_frames = 1 + + abs_indices, remaining, ep_lengths = self._compute_episode_metadata( + frame_indices, episode_indices, num_frames, is_batch + ) + observation['absolute_frame_indices'] = abs_indices + observation['remaining_length'] = remaining + observation['episode_length'] = ep_lengths - # Generate stage labels and refined progress from subtask annotations + # 6. Generate stage labels and progress targets from subtask annotations if self.temporal_proportions is not None and self.dataset_meta is not None: stage_labels, progress_targets = self._generate_stage_and_progress_labels( - frame_index, episode_index, observation.get('video_features') + frame_index, episode_index, video_features ) if stage_labels is not None: observation['stage_labels'] = stage_labels @@ -714,11 +647,10 @@ class SARMEncodingProcessorStep(ProcessorStep): type=FeatureType.LANGUAGE, shape=(self.config.text_dim,) ) - if self.config.use_joint_state: - features[PipelineFeatureType.OBSERVATION]['state_features'] = PolicyFeature( - type=FeatureType.STATE, - shape=(self.config.num_frames, self.config.state_dim) - ) + features[PipelineFeatureType.OBSERVATION]['state_features'] = PolicyFeature( + type=FeatureType.STATE, + shape=(self.config.num_frames, self.config.state_dim) + ) return features