cleanup and refactor

This commit is contained in:
Pepijn
2025-11-25 17:47:36 +01:00
parent 3b31c2d9d3
commit c774818eda
3 changed files with 407 additions and 657 deletions
@@ -25,14 +25,7 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("sarm")
@dataclass
class SARMConfig(PreTrainedConfig):
"""Configuration class for SARM (Stage-Aware Reward Modeling).
SARM is a dual-head reward model that jointly predicts:
1. High-level task stage (classification)
2. Fine-grained progress within each stage (regression)
It uses CLIP for visual encoding and supports joint state input.
"""
"""Configuration class for SARM (Stage-Aware Reward Modeling)"""
# Visual encoding parameters
image_dim: int = 512 # CLIP embedding dimension
@@ -40,21 +33,20 @@ class SARMConfig(PreTrainedConfig):
frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second)
# Text encoding parameters
text_dim: int = 384 # MiniLM embedding dimension
text_dim: int = 384
# Joint state parameters
state_dim: int | None = None # Auto-detected from dataset if None
use_joint_state: bool = True # Whether to use joint state input
# Architecture parameters
hidden_dim: int = 768 # Transformer hidden dimension
num_heads: int = 12 # Number of attention heads
num_layers: int = 8 # Number of transformer layers
hidden_dim: int = 768
num_heads: int = 12
num_layers: int = 8
num_stages: int = 5 # Number of task stages for classification (auto-updated from annotations if available)
subtask_names: list | None = None # List of subtask names (auto-populated from annotations)
# Temporal parameters
max_length: int = 9 # Maximum video sequence length (should match num_frames)
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
use_temporal_sampler: bool = True # Always enable temporal sequence loading
sampling_mode: str = "sarm" # Sampling mode: "sarm" or "rewind"
@@ -65,24 +57,13 @@ class SARMConfig(PreTrainedConfig):
dropout: float = 0.1 # Dropout rate
stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations
# RA-BC (Reward-Aligned Behavior Cloning) parameters
enable_rabc: bool = False # Enable RA-BC weighted loss
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
rabc_epsilon: float = 1e-6 # Small constant to avoid division by zero
chunk_length: int = 25 # Action chunk length for computing progress deltas
# Model loading
pretrained_model_path: str | None = None
# Device settings
device: str | None = None
# Processor settings
image_key: str = "observation.images.top" # Key for images in dataset
image_key: str = "observation.images.top" # Key for image used from the dataset
task_description: str = "perform the task" # Default task description
encode_on_the_fly: bool = True # Encode images/text during training
use_dataset_task: bool = True # Use task descriptions from dataset
use_subtask_annotations: bool = True # Use subtask annotations for stage-aware training if available
# video_features and text_features are generated by the processor from raw images/text, so they are not declared as VISUAL/LANGUAGE here (this avoids validation errors)
input_features: dict = field(default_factory=lambda: {
@@ -122,7 +103,7 @@ class SARMConfig(PreTrainedConfig):
if self.sampling_mode not in ["sarm", "rewind", "custom"]:
raise ValueError(
f"sampling_mode must be 'sarm', 'rewind', or 'custom', got {self.sampling_mode}"
f"sampling_mode must be 'sarm' or 'rewind', got {self.sampling_mode}"
)
def get_optimizer_preset(self) -> AdamWConfig:
+128 -291
View File
@@ -63,33 +63,30 @@ class SARMTransformer(nn.Module):
def __init__(
self,
video_dim: int = 512, # CLIP dimension
text_dim: int = 384, # MiniLM dimension
state_dim: int = 14, # Joint state dimension
video_dim: int = 512,
text_dim: int = 384,
state_dim: int = 14,
hidden_dim: int = 768,
num_heads: int = 12,
num_layers: int = 8,
num_stages: int = 5,
max_length: int = 9,
dropout: float = 0.1,
use_joint_state: bool = True
dropout: float = 0.1
):
super().__init__()
self.hidden_dim = hidden_dim
self.max_length = max_length
self.num_stages = num_stages
self.use_joint_state = use_joint_state
# Project video, text, and state to common dimension
# Project video, text, and state to same dimension
self.video_proj = nn.Linear(video_dim, hidden_dim)
self.text_proj = nn.Linear(text_dim, hidden_dim)
if use_joint_state:
self.state_proj = nn.Linear(state_dim, hidden_dim)
self.state_proj = nn.Linear(state_dim, hidden_dim)
# Position embedding only for the first frame
self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim))
# Transformer encoder (shared backbone)
# Transformer encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=num_heads,
@@ -157,12 +154,10 @@ class SARMTransformer(nn.Module):
# Project inputs to common dimension
video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim]
text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim]
# Add joint state if provided
if self.use_joint_state and state_features is not None:
state_embed = self.state_proj(state_features) # [batch_size, seq_len, hidden_dim]
# Fuse video and state features (simple addition)
video_embed = video_embed + state_embed
state_embed = self.state_proj(state_features) # [batch_size, seq_len, hidden_dim]
# Fuse video and state features (simple addition)
video_embed = video_embed + state_embed
# Add positional embedding to first video frame
video_embed[:, 0] += self.first_pos_embed
@@ -274,16 +269,10 @@ class SARMRewardModel(PreTrainedPolicy):
num_layers=config.num_layers,
num_stages=config.num_stages,
max_length=config.max_length,
dropout=config.dropout,
use_joint_state=config.use_joint_state
dropout=config.dropout
)
self.sarm_transformer.to(self.device)
# RA-BC running statistics (for weighted loss)
if config.enable_rabc:
self.register_buffer("rabc_mean", torch.tensor(0.0))
self.register_buffer("rabc_m2", torch.tensor(0.0))
self.register_buffer("rabc_count", torch.tensor(0))
logging.info(f"SARM Reward Model initialized on {self.device}")
@@ -474,40 +463,6 @@ class SARMRewardModel(PreTrainedPolicy):
return rewards
def _update_rabc_stats(self, progress_deltas: torch.Tensor):
"""Update running statistics for RA-BC using Welford's online algorithm."""
if not self.config.enable_rabc:
return
for delta in progress_deltas:
self.rabc_count += 1
delta_val = delta.item()
delta_mean = delta_val - self.rabc_mean
self.rabc_mean += delta_mean / self.rabc_count
delta_m2 = delta_val - self.rabc_mean
self.rabc_m2 += delta_mean * delta_m2
def _compute_rabc_weights(self, progress_deltas: torch.Tensor) -> torch.Tensor:
"""Compute RA-BC weights for progress deltas."""
if not self.config.enable_rabc or self.rabc_count < 2:
return torch.ones_like(progress_deltas)
# Get running statistics
mean = max(self.rabc_mean.item(), 0.0) # Clamp mean to non-negative
variance = self.rabc_m2 / (self.rabc_count - 1)
std = torch.sqrt(variance).item()
# Compute soft weights
lower_bound = mean - 2 * std
upper_bound = mean + 2 * std
weights = (progress_deltas - lower_bound) / (4 * std + self.config.rabc_epsilon)
weights = torch.clamp(weights, 0.0, 1.0)
# Apply hard threshold
high_quality_mask = progress_deltas > self.config.rabc_kappa
weights = torch.where(high_quality_mask, torch.ones_like(weights), weights)
return weights
def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False):
"""Load pretrained model weights from a checkpoint file."""
@@ -565,274 +520,169 @@ class SARMRewardModel(PreTrainedPolicy):
"""Required by PreTrainedPolicy but not used for SARM."""
raise NotImplementedError("SARM model does not select actions")
def _get_remaining_length(self, observation: dict, idx: int) -> float | None:
"""Extract remaining length for a sample from observation metadata."""
remaining_lengths = observation.get('remaining_length')
if remaining_lengths is None:
return None
if isinstance(remaining_lengths, torch.Tensor):
return remaining_lengths[idx].item() if remaining_lengths.dim() > 0 else remaining_lengths.item()
return remaining_lengths
def _compute_progress_targets(self, remaining_length: float | None, seq_len: int) -> torch.Tensor:
"""Compute progress targets based on remaining trajectory length."""
if remaining_length is not None and remaining_length > 0:
return torch.arange(1, seq_len + 1, dtype=torch.float32, device=self.device) / remaining_length
else:
raise ValueError("Remaining length is None, but is required for progress targets")
def _apply_rewind_augmentation(
self,
video: torch.Tensor,
progress: torch.Tensor,
state: torch.Tensor | None,
max_length: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""Apply rewind augmentation: append 2-4 reversed frames (SARM paper)."""
num_reverse = random.randint(2, min(4, max_length - 1))
# Reverse and take frames (skip first which is last of original)
reversed_video = video.flip(0)[1:num_reverse + 1]
reversed_progress = progress.flip(0)[1:num_reverse + 1]
# Concatenate and trim
video = torch.cat([video, reversed_video], dim=0)[:max_length]
progress = torch.cat([progress, reversed_progress], dim=0)[:max_length]
if state is not None:
reversed_state = state.flip(0)[1:num_reverse + 1]
state = torch.cat([state, reversed_state], dim=0)[:max_length]
return video, progress, state
def _ensure_sequence_length(self, tensor: torch.Tensor, target_len: int) -> torch.Tensor:
"""Pad or trim tensor to target length."""
current_len = tensor.shape[0]
if current_len == target_len:
return tensor
if current_len < target_len:
padding = target_len - current_len
return torch.cat([tensor, tensor[-1:].expand(padding, *tensor.shape[1:])])
return tensor[:target_len]
def forward(self, batch):
"""
Forward pass compatible with lerobot training pipeline.
Forward pass for SARM reward model training.
Args:
batch: Dictionary containing observation with:
- 'video_features': Pre-encoded video features (B, T, 512)
- 'text_features': Pre-encoded text features (B, 384)
- 'state_features': Joint state features (B, T, state_dim)
batch: Dictionary with 'observation' containing:
- 'video_features': (B, T, 512) pre-encoded video features
- 'text_features': (B, 384) pre-encoded text features
- 'state_features': (B, T, state_dim) joint state features
- 'remaining_length': (B,) remaining trajectory lengths (optional)
- 'stage_labels': (B, T) stage labels (optional, from annotations)
- 'progress_targets': (B, T, 1) progress targets (optional, from annotations)
Returns:
loss: Total training loss
output_dict: Dictionary of loss components for logging
Tuple of (total_loss, output_dict with loss components)
"""
# Extract from observation dict
observation = batch.get('observation', batch)
# Extract required features
video_features = observation['video_features'].to(self.device)
text_features = observation['text_features'].to(self.device)
state_features = observation.get('state_features', None)
state_features = observation.get('state_features')
if state_features is not None:
state_features = state_features.to(self.device)
# Extract stage labels and progress targets if available (from subtask annotations)
stage_labels = observation.get('stage_labels', None)
if stage_labels is not None:
stage_labels = stage_labels.to(self.device)
progress_targets_from_annotations = observation.get('progress_targets', None)
if progress_targets_from_annotations is not None:
progress_targets_from_annotations = progress_targets_from_annotations.to(self.device)
batch_size = video_features.shape[0]
max_length = self.config.num_frames
# Handle both single frames and sequences
# Ensure 3D video features (B, T, D)
if video_features.dim() == 2:
# Single frames: replicate to create pseudo-sequences
video_features = video_features.unsqueeze(1).repeat(1, max_length, 1)
video_features = video_features.unsqueeze(1).expand(-1, max_length, -1)
if state_features is not None and state_features.dim() == 2:
# Single state: replicate to match sequence length
state_features = state_features.unsqueeze(1).repeat(1, max_length, 1)
# Apply rewind augmentation (following SARM paper: up to 4 reversed frames)
# Note: video_features are already sampled by dataset (9 frames with 30-frame gaps)
# We just need to compute progress targets and optionally apply rewind
state_features = state_features.unsqueeze(1).expand(-1, max_length, -1)
# Process each sample: compute progress targets and apply rewind augmentation
processed_videos = []
processed_states = []
progress_targets = []
# Extract episode metadata for correct progress normalization
absolute_frame_indices = observation.get('absolute_frame_indices', None)
episode_lengths = observation.get('episode_length', None)
remaining_lengths = observation.get('remaining_length', None)
for i in range(batch_size):
# Get metadata for this sample
current_absolute_indices = None
current_episode_length = None
current_remaining_length = None
remaining_length = self._get_remaining_length(observation, i)
progress = self._compute_progress_targets(remaining_length, max_length)
if absolute_frame_indices is not None:
if isinstance(absolute_frame_indices, list):
current_absolute_indices = absolute_frame_indices[i]
else:
current_absolute_indices = absolute_frame_indices
video = video_features[i]
state = state_features[i] if state_features is not None else None
if episode_lengths is not None:
if isinstance(episode_lengths, torch.Tensor) and episode_lengths.dim() > 0:
current_episode_length = episode_lengths[i].item()
else:
current_episode_length = episode_lengths.item() if isinstance(episode_lengths, torch.Tensor) else episode_lengths
if remaining_lengths is not None:
if isinstance(remaining_lengths, torch.Tensor) and remaining_lengths.dim() > 0:
current_remaining_length = remaining_lengths[i].item()
else:
current_remaining_length = remaining_lengths.item() if isinstance(remaining_lengths, torch.Tensor) else remaining_lengths
# Compute progress targets directly from metadata (frames already loaded by dataset)
# Progress = (position_in_sequence + 1) / remaining_trajectory_length
if current_remaining_length is not None and current_remaining_length > 0:
# Correct: relative progress from first loaded frame to episode end
progress_indices = torch.arange(1, max_length + 1, dtype=torch.float32, device=self.device)
progress = progress_indices / current_remaining_length
else:
# Fallback: linear progress (when metadata is not available)
logging.warning(f"Sample {i}: No remaining_length metadata, using linear progress fallback")
progress = torch.linspace(1.0/max_length, 1.0, max_length, device=self.device)
# Apply rewind augmentation with 50% probability (following SARM paper)
# Paper specifies: "appending up to four frames from earlier timestamps with reversed order"
# Apply rewind augmentation with 50% probability (SARM paper)
if random.random() < 0.5:
# Rewind: append 2-4 reversed frames, trim to max_length
num_reverse = random.randint(2, min(4, max_length - 1))
# Reverse video and progress
reversed_video = video_features[i].flip(0)
reversed_progress = progress.flip(0)
# Take frames from reversed (skip first which is last of original)
reverse_frames = reversed_video[1:num_reverse+1]
reverse_progress = reversed_progress[1:num_reverse+1]
# Concatenate forward + reversed
rewound_video = torch.cat([video_features[i], reverse_frames], dim=0)
rewound_progress = torch.cat([progress, reverse_progress], dim=0)
# Trim to max_length
rewound_video = rewound_video[:max_length]
rewound_progress = rewound_progress[:max_length]
processed_videos.append(rewound_video)
progress_targets.append(rewound_progress)
# Process state features if available
if state_features is not None:
reversed_state = state_features[i].flip(0)
reverse_state_frames = reversed_state[1:num_reverse+1]
rewound_state = torch.cat([state_features[i], reverse_state_frames], dim=0)
rewound_state = rewound_state[:max_length]
processed_states.append(rewound_state)
else:
# Normal: use frames as-is with forward progress
processed_videos.append(video_features[i])
progress_targets.append(progress)
# Process state features if available
if state_features is not None:
processed_states.append(state_features[i])
video, progress, state = self._apply_rewind_augmentation(video, progress, state, max_length)
# Ensure correct sequence length
video = self._ensure_sequence_length(video, max_length)
progress = self._ensure_sequence_length(progress.unsqueeze(-1), max_length).squeeze(-1)
if state is not None:
state = self._ensure_sequence_length(state, max_length)
processed_videos.append(video)
progress_targets.append(progress)
if state is not None:
processed_states.append(state)
# Ensure all sequences have the same length before stacking
# (sampling functions should return max_length, but double-check)
validated_videos = []
validated_progress = []
for i, (vid, prog) in enumerate(zip(processed_videos, progress_targets)):
if len(vid) != max_length:
logging.warning(f"Sample {i}: video length {len(vid)} != {max_length}, padding/trimming")
if len(vid) < max_length:
# Pad
padding = max_length - len(vid)
vid = torch.cat([vid, vid[-1:].repeat(padding, 1)])
prog = torch.cat([prog, torch.full((padding,), prog[-1], device=prog.device)])
else:
# Trim
vid = vid[:max_length]
prog = prog[:max_length]
validated_videos.append(vid)
validated_progress.append(prog)
# Stack into batches
processed_videos = torch.stack(processed_videos)
progress_targets = torch.stack(progress_targets).unsqueeze(-1) # (B, T, 1)
processed_states = torch.stack(processed_states) if processed_states else None
# Stack processed features
processed_videos = torch.stack(validated_videos)
progress_targets = torch.stack(validated_progress)
# Ensure progress_targets has the same shape as progress_preds
# progress_preds is (batch_size, num_frames, 1)
# progress_targets is (batch_size, num_frames) -> add last dimension
if progress_targets.dim() == 2:
progress_targets = progress_targets.unsqueeze(-1) # (batch_size, num_frames, 1)
if state_features is not None and len(processed_states) > 0:
processed_states = torch.stack(processed_states)
else:
processed_states = None
# Get predictions
# Get model predictions
stage_logits, stage_probs, progress_preds = self.sarm_transformer(
processed_videos, text_features, processed_states
)
# Use annotation-based progress targets if available, otherwise use computed ones
if progress_targets_from_annotations is not None and len(processed_videos) == 1:
# Use refined progress from subtask annotations (single sample case)
# Ensure shapes match
if progress_targets_from_annotations.shape != progress_preds.shape:
if progress_targets_from_annotations.dim() == 2:
progress_targets_from_annotations = progress_targets_from_annotations.unsqueeze(0)
progress_targets = progress_targets_from_annotations
# Use annotation-based progress targets
progress_from_annotations = observation.get('progress_targets')
if progress_from_annotations is not None:
progress_from_annotations = progress_from_annotations.to(self.device)
if progress_from_annotations.dim() == 2:
progress_from_annotations = progress_from_annotations.unsqueeze(-1)
if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
progress_targets = progress_from_annotations
# Compute progress loss using targets
# Compute progress loss
progress_loss = F.mse_loss(progress_preds, progress_targets)
output_dict = {'progress_loss': progress_loss.item()}
total_loss = progress_loss
# Compute stage loss if labels are available
stage_loss = None
if stage_labels is not None and len(processed_videos) == 1:
# Ensure stage_labels matches the sequence length
if stage_labels.dim() == 1 and stage_logits.dim() == 3:
# stage_labels: (seq_len,) -> need to expand to (batch, seq_len)
stage_labels = stage_labels.unsqueeze(0).expand(stage_logits.shape[0], -1)
elif stage_labels.shape[0] != stage_logits.shape[0]:
# Single label for batch - expand
stage_labels = stage_labels.expand(stage_logits.shape[0], stage_logits.shape[1])
# Compute cross-entropy loss for stage classification
# Compute stage loss if labels available
stage_labels = observation.get('stage_labels')
if stage_labels is not None:
stage_labels = stage_labels.to(self.device)
if stage_labels.dim() == 1:
stage_labels = stage_labels.unsqueeze(0).expand(batch_size, -1)
stage_loss = compute_stage_loss(stage_logits, stage_labels)
# Combine losses
if stage_loss is not None:
total_loss = progress_loss + self.config.stage_loss_weight * stage_loss
output_dict = {
'progress_loss': progress_loss.item(),
'stage_loss': stage_loss.item(),
}
total_loss = total_loss + self.config.stage_loss_weight * stage_loss
output_dict['stage_loss'] = stage_loss.item()
else:
total_loss = progress_loss
output_dict = {
'progress_loss': progress_loss.item(),
}
raise ValueError("Stage labels are None, but are required for stage loss")
# Compute misaligned loss (following SARM paper and ReWiND)
# "To improve video-language alignment, task descriptions are occasionally perturbed"
if random.random() < 0.2: # 20% probability (matching ReWiND)
# Create misaligned pairs by shuffling text features
# Misaligned loss: 20% probability (SARM paper - improve video-language alignment)
if random.random() < 0.2:
shuffle_idx = torch.randperm(batch_size, device=self.device)
misaligned_texts = text_features[shuffle_idx]
# Get predictions for misaligned pairs (should predict zero progress)
_, _, misaligned_preds = self.sarm_transformer(
processed_videos, misaligned_texts, processed_states
processed_videos, text_features[shuffle_idx], processed_states
)
# Target is zero progress for misaligned pairs
target_zeros = torch.zeros_like(misaligned_preds)
misaligned_loss = F.mse_loss(misaligned_preds, target_zeros)
# Add to total loss
misaligned_loss = F.mse_loss(misaligned_preds, torch.zeros_like(misaligned_preds))
total_loss = total_loss + misaligned_loss
output_dict['misaligned_loss'] = misaligned_loss.item()
# RA-BC weighted loss (if enabled)
if self.config.enable_rabc:
# Compute progress deltas (simplified: use consecutive frame differences)
progress_deltas = progress_preds[:, 1:, 0] - progress_preds[:, :-1, 0]
progress_deltas = progress_deltas.mean(dim=1) # Average over sequence
# Update running statistics
self._update_rabc_stats(progress_deltas)
# Compute weights
weights = self._compute_rabc_weights(progress_deltas)
# Apply weighted loss
weighted_loss = (total_loss * weights.mean()).sum()
total_loss = weighted_loss
# Add final total loss to output dict
output_dict['total_loss'] = total_loss.item()
return total_loss, output_dict
# Loss utilities
def compute_stage_loss(
stage_logits: torch.Tensor,
target_stages: torch.Tensor
) -> torch.Tensor:
"""
Compute stage classification loss.
Args:
stage_logits: Stage predictions (batch_size, num_frames, num_stages)
target_stages: Target stage indices (batch_size, num_frames)
Returns:
Cross-entropy loss
"""
batch_size, num_frames, num_stages = stage_logits.shape
def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor:
_, _, num_stages = stage_logits.shape
stage_logits_flat = stage_logits.reshape(-1, num_stages)
target_stages_flat = target_stages.reshape(-1)
@@ -840,20 +690,7 @@ def compute_stage_loss(
return loss
def compute_progress_loss(
progress_preds: torch.Tensor,
target_progress: torch.Tensor
) -> torch.Tensor:
"""
Compute progress regression loss.
Args:
progress_preds: Progress predictions (batch_size, num_frames, 1)
target_progress: Target progress values (batch_size, num_frames, 1)
Returns:
Mean squared error loss
"""
def compute_progress_loss(progress_preds: torch.Tensor, target_progress: torch.Tensor) -> torch.Tensor:
    """Return the mean squared error between predicted and target progress."""
    return F.mse_loss(progress_preds, target_progress)
+271 -339
View File
@@ -15,10 +15,12 @@
# limitations under the License.
import logging
from typing import Dict, Any, List, Optional
from typing import Any
import numpy as np
import torch
from PIL import Image
import pandas as pd
from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPProcessor
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.processor import (
@@ -68,16 +70,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Compute temporal proportions from subtask annotations if available
self.temporal_proportions = None
self.subtask_names = None
if dataset_meta is not None and config.use_subtask_annotations:
if dataset_meta is not None:
self._compute_temporal_proportions()
# Initialize encoders
self._init_encoders()
def _init_encoders(self):
"""Initialize CLIP and MiniLM encoders."""
from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPProcessor
device = torch.device(
self.config.device if self.config.device
else "cuda" if torch.cuda.is_available() else "cpu"
@@ -116,13 +115,11 @@ class SARMEncodingProcessorStep(ProcessorStep):
logging.info("No subtask annotations found in dataset")
return
# Convert to pandas for easier processing
import pandas as pd
# Convert to pandas
episodes_df = episodes.to_pandas()
# Collect all subtask names and compute average durations
subtask_durations = {}
subtask_counts = {}
all_subtask_names = set()
for ep_idx in episodes_df.index:
@@ -178,44 +175,166 @@ class SARMEncodingProcessorStep(ProcessorStep):
logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}")
def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features):
"""Generate stage labels and refined progress targets from subtask annotations.
def _to_numpy_array(self, x) -> np.ndarray:
"""Convert input to a 1D numpy array."""
if isinstance(x, torch.Tensor):
arr = x.cpu().numpy()
else:
arr = np.array(x)
if arr.ndim == 0:
arr = np.array([arr.item()])
return arr
def _find_episode_for_frame(self, frame_idx: int) -> int:
"""Find the episode index for a given frame index."""
for ep_idx in range(len(self.dataset_meta.episodes)):
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
if ep_start <= frame_idx < ep_end:
return ep_idx
return 0 # Fallback
def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
    """Return one episode index per entry of ``frame_indices``.

    The supplied ``episode_index`` is used when it provides an index per
    frame. When it is ``None``, or holds a single value while several frames
    were given, each frame's episode is looked up individually.
    """
    if episode_index is not None:
        episode_indices = self._to_numpy_array(episode_index)
        # Keep the given indices unless one value would have to cover many frames.
        if not (len(episode_indices) == 1 and len(frame_indices) > 1):
            return episode_indices

    return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
def _compute_absolute_indices(self, frame_idx: int, ep_start: int, num_frames: int) -> torch.Tensor:
"""Compute absolute frame indices for a sequence."""
frame_gap = getattr(self.config, 'frame_gap', 1)
if frame_gap > 1:
indices = [max(ep_start, frame_idx - (num_frames - 1 - i) * frame_gap) for i in range(num_frames)]
return torch.tensor(indices)
else:
start_idx = max(ep_start, frame_idx - num_frames + 1)
return torch.arange(start_idx, frame_idx + 1)
def _compute_episode_metadata(
self,
frame_indices: np.ndarray,
episode_indices: np.ndarray,
num_frames: int,
is_batch: bool,
) -> tuple[list | torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute episode metadata for all samples.
Args:
frame_index: Current frame index or indices
episode_index: Episode index
video_features: Video features tensor to determine sequence length
Returns:
Tuple of (stage_labels, progress_targets) or (None, None) if no annotations
Tuple of (absolute_frame_indices, remaining_lengths, episode_lengths)
"""
if self.temporal_proportions is None or episode_index is None:
return None, None
absolute_indices_list = []
remaining_lengths = []
episode_lengths = []
# Convert to pandas to access annotations
import pandas as pd
episodes_df = self.dataset_meta.episodes.to_pandas()
# Handle batch processing
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()):
ep_idx, frame_idx = int(ep_idx), int(frame_idx)
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
episode_lengths.append(ep_end - ep_start)
abs_indices = self._compute_absolute_indices(frame_idx, ep_start, num_frames)
absolute_indices_list.append(abs_indices)
remaining_lengths.append(ep_end - abs_indices[0].item())
if is_batch:
# Process multiple samples - for now, return None
# (batch processing of annotations is complex and not critical)
return None, None
# Single sample processing
if isinstance(episode_index, torch.Tensor):
ep_idx = int(episode_index.item())
return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths)
else:
ep_idx = int(episode_index)
return absolute_indices_list[0], remaining_lengths[0], episode_lengths[0]
def _compute_stage_and_progress_for_frame(
self,
current_frame: int,
subtask_names: list,
subtask_start_frames: list,
subtask_end_frames: list,
) -> tuple[int, float]:
"""Compute stage index and cumulative progress for a single frame.
if isinstance(frame_index, torch.Tensor):
frame_idx = int(frame_index.item())
Args:
current_frame: Frame index relative to episode start
subtask_names: List of subtask names for this episode
subtask_start_frames: List of subtask start frames
subtask_end_frames: List of subtask end frames
Returns:
Tuple of (stage_idx, cumulative_progress)
"""
stage_idx = -1
cumulative_progress = 0.0
# Find which subtask this frame belongs to
for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)):
if current_frame >= start_frame and current_frame <= end_frame:
# Found the subtask
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0
# Calculate within-subtask progress
subtask_duration = end_frame - start_frame
if subtask_duration > 0:
within_subtask_progress = (current_frame - start_frame) / subtask_duration
else:
within_subtask_progress = 1.0
# Calculate cumulative progress from completed subtasks
for k in range(j):
prev_name = subtask_names[k]
if prev_name in self.temporal_proportions:
cumulative_progress += self.temporal_proportions[prev_name]
# Add current subtask's partial progress
if name in self.temporal_proportions:
cumulative_progress += self.temporal_proportions[name] * within_subtask_progress
return stage_idx, cumulative_progress
# No matching subtask found - estimate based on position
if current_frame < subtask_start_frames[0]:
return 0, 0.0
elif current_frame > subtask_end_frames[-1]:
return len(self.subtask_names) - 1, 1.0
else:
frame_idx = int(frame_index)
# Between subtasks - use previous subtask's end state
for j in range(len(subtask_names) - 1):
if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]:
name = subtask_names[j]
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j
# Sum up all completed subtasks
for k in range(j + 1):
prev_name = subtask_names[k]
if prev_name in self.temporal_proportions:
cumulative_progress += self.temporal_proportions[prev_name]
return stage_idx, cumulative_progress
# Get subtask annotations for this episode
return 0, 0.0 # Fallback
def _compute_labels_for_sample(
self,
frame_idx: int,
ep_idx: int,
seq_len: int,
episodes_df: pd.DataFrame,
) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]:
"""Compute stage labels and progress targets for a single sample.
Args:
frame_idx: The frame index for this sample
ep_idx: The episode index
seq_len: Number of frames in the sequence
episodes_df: DataFrame with episode metadata
Returns:
Tuple of (stage_labels, progress_targets) tensors with shapes (T,) and (T, 1),
or (None, None) if no valid annotations
"""
# NOTE(review): this span reads like a flattened diff -- indentation is lost and
# statements from two revisions appear interleaved. The notes below are hedged accordingly.
# Check if episode has valid annotations
if ep_idx >= len(episodes_df):
return None, None
# NOTE(review): the next line is a diff hunk header, not Python. Whatever sits between the
# guard above and the episode-boundary lookup (presumably where subtask_names and the
# subtask start/end frames are loaded) is not visible here -- TODO confirm.
@@ -228,21 +347,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Get episode boundaries
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
# NOTE(review): `video_features` is not a parameter of this method, while `seq_len` is;
# the four lines below presumably belong to the pre-refactor caller -- TODO confirm.
# Determine sequence length
if video_features is not None and video_features.dim() > 0:
seq_len = video_features.shape[0] if video_features.dim() == 2 else video_features.shape[1]
else:
seq_len = 1
# NOTE(review): frame_gap is assigned here and again, with the same expression, a few lines below.
# Get frame gap for temporal sampling
frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
# Generate labels for each frame in the sequence
stage_labels = []
progress_targets = []
# Get frame gap for temporal sampling
frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
for i in range(seq_len):
# Calculate actual frame index for this position in sequence
if frame_gap > 1:
# NOTE(review): another hunk header follows; the body of the `frame_gap > 1` branch is not visible.
@@ -251,326 +363,147 @@ class SARMEncodingProcessorStep(ProcessorStep):
else:
# Contiguous window ending at frame_idx, expressed relative to ep_start and clamped at 0.
current_frame = max(0, frame_idx - seq_len + 1 + i - ep_start)
# NOTE(review): the inline lookup from here down to the helper call sets the same two
# values (stage_idx, cumulative_progress) that the helper call then assigns; presumably the
# inline version is the pre-refactor code -- TODO confirm.
# Find which subtask this frame belongs to
stage_idx = -1
within_subtask_progress = 0.0
cumulative_progress = 0.0
for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)):
if current_frame >= start_frame and current_frame <= end_frame:
# Found the subtask
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0
# Calculate within-subtask progress
subtask_duration = end_frame - start_frame
if subtask_duration > 0:
within_subtask_progress = (current_frame - start_frame) / subtask_duration
else:
within_subtask_progress = 1.0
# Calculate cumulative progress
for k in range(j):
prev_name = subtask_names[k]
if prev_name in self.temporal_proportions:
cumulative_progress += self.temporal_proportions[prev_name]
# Add current subtask's partial progress
if name in self.temporal_proportions:
cumulative_progress += self.temporal_proportions[name] * within_subtask_progress
break
# If no matching subtask found, estimate based on position
if stage_idx == -1:
# Estimate stage based on frame position
if current_frame < subtask_start_frames[0]:
stage_idx = 0
cumulative_progress = 0.0
elif current_frame > subtask_end_frames[-1]:
stage_idx = len(self.subtask_names) - 1
cumulative_progress = 1.0
else:
# Between subtasks - use previous subtask's end state
for j in range(len(subtask_names) - 1):
if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]:
name = subtask_names[j]
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j
# Sum up all previous subtasks
for k in range(j + 1):
prev_name = subtask_names[k]
if prev_name in self.temporal_proportions:
cumulative_progress += self.temporal_proportions[prev_name]
break
# Helper-based lookup: returns the (stage index, cumulative progress) pair for one frame.
stage_idx, cumulative_progress = self._compute_stage_and_progress_for_frame(
current_frame, subtask_names, subtask_start_frames, subtask_end_frames
)
stage_labels.append(stage_idx)
progress_targets.append(cumulative_progress)
# Convert to tensors
stage_labels = torch.tensor(stage_labels, dtype=torch.long)
# NOTE(review): the two assignments below are identical apart from the trailing comment.
# unsqueeze(-1) appends a trailing dimension of size 1 to the progress tensor.
progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1) # Add channel dim
progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1)
return stage_labels, progress_targets
def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features):
    """Build stage labels and progress targets from subtask annotations.

    Args:
        frame_index: Frame index, either a scalar or a tensor of indices.
        episode_index: Episode index, either a scalar or a tensor of indices.
        video_features: Encoded video features; only the shape is read, to
            infer how many frames each sample's sequence holds.

    Returns:
        ``(stage_labels, progress_targets)``. ``(None, None)`` is returned
        when temporal proportions or the episode index are unavailable.
        When ``frame_index`` is a tensor with more than one element the
        per-sample results are stacked along a new leading dimension;
        otherwise the single sample's tensors are returned unchanged.
    """
    if self.temporal_proportions is None or episode_index is None:
        return None, None

    # A tensor carrying more than one frame index is treated as a batch.
    batched = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1

    frames = self._to_numpy_array(frame_index)
    episodes = self._get_episode_indices(frames, episode_index)

    # The sequence axis is axis 1 for batched features and axis 0 otherwise;
    # without usable features every sample is a single frame.
    seq_len = 1
    if video_features is not None and video_features.dim() >= 2:
        time_axis = 1 if batched else 0
        seq_len = video_features.shape[time_axis]

    episodes_df = self.dataset_meta.episodes.to_pandas()

    stage_chunks = []
    progress_chunks = []
    for episode, frame in zip(episodes.tolist(), frames.tolist()):
        stages, progress = self._compute_labels_for_sample(int(frame), int(episode), seq_len, episodes_df)
        if stages is None:
            # No usable annotations for this sample: fall back to all-zero labels.
            stages = torch.zeros(seq_len, dtype=torch.long)
            progress = torch.zeros(seq_len, 1, dtype=torch.float32)
        stage_chunks.append(stages)
        progress_chunks.append(progress)

    if not batched:
        return stage_chunks[0], progress_chunks[0]
    return torch.stack(stage_chunks, dim=0), torch.stack(progress_chunks, dim=0)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Encode images, text, and normalize states in the transition."""
# NOTE(review): this span reads like a flattened diff -- indentation is lost, statements from
# two revisions appear interleaved, and the end of the method is not visible. The notes below
# only describe what the listed lines show and are hedged where the revision is unclear.
from lerobot.processor.core import TransitionKey
# NOTE(review): new_transition is bound twice below (once through self._current_transition,
# once directly); presumably the first pair is the older form -- TODO confirm.
self._current_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
new_transition = self._current_transition
new_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
observation = new_transition.get(TransitionKey.OBSERVATION)
# NOTE(review): two observation guards follow -- the first returns the transition unchanged,
# the second raises ValueError. As listed, the second can never trigger after the first;
# presumably only one of them is current -- TODO confirm.
if observation is None or not isinstance(observation, dict):
return new_transition
if not isinstance(observation, dict):
raise ValueError("Observation must be a dictionary")
# Extract and encode images
batch_size = 1
if self.image_key in observation:
image = observation[self.image_key]
# Handle different image formats
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
# Encode images
video_features = self._encode_images_batch(image)
observation['video_features'] = video_features
# Get batch size from encoded features
batch_size = video_features.shape[0]
# 1. Encode images with CLIP
image = observation.get(self.image_key)
if image is None:
raise ValueError(f"Image not found in observation for key: {self.image_key}")
# Extract and normalize joint states
if self.config.use_joint_state:
# Look for "state" or "observation.state" in observation
state_key = None
state_data = None
if "state" in observation:
state_key = "state"
state_data = observation["state"]
elif "observation.state" in observation:
state_key = "observation.state"
state_data = observation["observation.state"]
if state_data is not None:
if isinstance(state_data, torch.Tensor):
state_data = state_data.cpu().numpy()
# Normalize if stats available
if self.dataset_stats and state_key in self.dataset_stats:
mean = self.dataset_stats[state_key]['mean']
std = self.dataset_stats[state_key]['std']
state_data = (state_data - mean) / (std + 1e-8)
observation['state_features'] = torch.tensor(state_data, dtype=torch.float32)
else:
# Create dummy state features if not found
if 'video_features' in observation:
num_frames = observation['video_features'].shape[0] if observation['video_features'].dim() == 2 else observation['video_features'].shape[1]
observation['state_features'] = torch.zeros(batch_size, num_frames, self.config.state_dim)
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
video_features = self._encode_images_batch(image)
observation['video_features'] = video_features
# Get task descriptions
task_descriptions = None
if 'task' in new_transition:
task_descriptions = new_transition['task']
# 2. Extract and normalize joint states
# NOTE(review): `a or b` tests the truthiness of the first value. The state may be a
# torch.Tensor (see the isinstance check just below); bool() on a tensor with more than one
# element raises RuntimeError, and a one-element tensor equal to 0 would silently fall through
# to the second key. An explicit `is None` check per key would avoid both -- TODO fix.
state_data = observation.get("state") or observation.get("observation.state")
if state_data is None:
raise ValueError("State data not found in observation (expected 'state' or 'observation.state')")
if isinstance(state_data, torch.Tensor):
state_data = state_data.cpu().numpy()
# The stats lookup key mirrors whichever observation key is present, preferring "state".
state_key = "state" if "state" in observation else "observation.state"
if self.dataset_stats and state_key in self.dataset_stats:
mean = self.dataset_stats[state_key]['mean']
std = self.dataset_stats[state_key]['std']
# The 1e-8 term keeps the division finite when std is zero.
state_data = (state_data - mean) / (std + 1e-8)
observation['state_features'] = torch.tensor(state_data, dtype=torch.float32)
# 3. Encode text with MiniLM
# NOTE(review): batch_size is the leading dimension of video_features; if
# _encode_images_batch can return unbatched (frames, dim) features this would be the frame
# count rather than a batch size -- TODO confirm the returned shape.
batch_size = video_features.shape[0]
task_descriptions = new_transition.get('task')
if task_descriptions is not None:
if isinstance(task_descriptions, str):
# A single task string is repeated so there is one description per batch entry.
task_descriptions = [task_descriptions] * batch_size
# Encode text
# NOTE(review): both a `text_features` local and direct writes to
# observation['text_features'] appear below; presumably two revisions of the same branch.
if task_descriptions is not None:
text_features = self._encode_text_batch_list(task_descriptions)
observation['text_features'] = self._encode_text_batch_list(task_descriptions)
else:
text_features = self._encode_text_batch(self.task_description, batch_size)
observation['text_features'] = self._encode_text_batch(self.task_description, batch_size)
observation['text_features'] = text_features
# 4. Extract frame/episode indices from complementary data
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if not isinstance(comp_data, dict):
raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
# Compute episode metadata for progress normalization
# Note: Processor runs BEFORE batching, so we need to extract from raw dataset structure
# The dataset provides episode_index and index in the raw item
frame_index = comp_data.get('index')
episode_index = comp_data.get('episode_index')
# NOTE(review): as listed, the two `= None` resets below sit between reading the indices
# and validating them, which would make the validation always raise; presumably the resets
# belong to the older lookup further down -- TODO confirm.
# Extract index and episode_index from COMPLEMENTARY_DATA
episode_index = None
frame_index = None
if frame_index is None:
raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA")
if episode_index is None:
raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")
# Primary location: COMPLEMENTARY_DATA (confirmed from debug logs)
if TransitionKey.COMPLEMENTARY_DATA in new_transition:
comp_data = new_transition[TransitionKey.COMPLEMENTARY_DATA]
if isinstance(comp_data, dict):
frame_index = comp_data.get('index')
episode_index = comp_data.get('episode_index')
# Fallback: check other locations
if frame_index is None and TransitionKey.OBSERVATION in new_transition:
obs = new_transition[TransitionKey.OBSERVATION]
if isinstance(obs, dict):
frame_index = obs.get('index')
if episode_index is None:
episode_index = obs.get('episode_index')
# If we have frame_index but no episode_index, compute it from episode boundaries
if frame_index is not None and episode_index is None and self.dataset_meta is not None:
# Convert to int if needed
if isinstance(frame_index, torch.Tensor):
frame_idx = frame_index.item() if frame_index.numel() == 1 else frame_index[0].item()
else:
frame_idx = int(frame_index)
# Search through episodes to find which one this frame belongs to
for ep_idx in range(len(self.dataset_meta.episodes)):
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
if ep_start <= frame_idx < ep_end:
episode_index = ep_idx
break
if self.dataset_meta is not None and frame_index is not None:
# Handle batch processing
# 5. Compute episode metadata if dataset_meta is available
if self.dataset_meta is not None:
# A tensor carrying more than one frame index is treated as a batch.
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
frame_indices = self._to_numpy_array(frame_index)
episode_indices = self._get_episode_indices(frame_indices, episode_index)
# NOTE(review): the long inline batch / single-sample computation from here down to the
# `_compute_episode_metadata` call writes the same three observation keys that the helper
# call then writes; presumably it is the pre-refactor code -- TODO confirm.
if is_batch:
# Batch case: process multiple samples at once
batch_size = frame_index.shape[0]
frame_indices = frame_index.cpu().numpy() if isinstance(frame_index, torch.Tensor) else np.array(frame_index)
# Ensure at least 1D
if frame_indices.ndim == 0:
frame_indices = np.array([frame_indices.item()])
# Compute episode_index for each frame if not provided
if episode_index is None:
episode_indices = []
for frame_idx in frame_indices:
frame_idx = int(frame_idx)
# Search through episodes
found = False
for ep_idx in range(len(self.dataset_meta.episodes)):
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
if ep_start <= frame_idx < ep_end:
episode_indices.append(ep_idx)
found = True
break
if not found:
episode_indices.append(0) # Fallback
episode_indices = np.array(episode_indices)
else:
episode_indices = episode_index.cpu().numpy() if isinstance(episode_index, torch.Tensor) else np.array(episode_index)
# Ensure at least 1D
if episode_indices.ndim == 0:
episode_indices = np.array([episode_indices.item()])
# CRITICAL FIX: If we have a single episode_index but multiple frame_indices,
# compute the correct episode for each frame (they might be from different episodes)
if len(episode_indices) == 1 and len(frame_indices) > 1:
episode_indices = []
for frame_idx in frame_indices:
frame_idx = int(frame_idx)
# Search through episodes
found = False
for ep_idx in range(len(self.dataset_meta.episodes)):
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
if ep_start <= frame_idx < ep_end:
episode_indices.append(ep_idx)
found = True
break
if not found:
episode_indices.append(0) # Fallback
episode_indices = np.array(episode_indices)
# Compute metadata for each sample in batch
absolute_indices_list = []
remaining_lengths = []
episode_lengths = []
# Convert to list for safe iteration
episode_indices_list = episode_indices.tolist() if hasattr(episode_indices, 'tolist') else list(episode_indices)
frame_indices_list = frame_indices.tolist() if hasattr(frame_indices, 'tolist') else list(frame_indices)
for i, (ep_idx, frame_idx) in enumerate(zip(episode_indices_list, frame_indices_list)):
ep_idx = int(ep_idx)
frame_idx = int(frame_idx)
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
episode_length = ep_end - ep_start
episode_lengths.append(episode_length)
# Compute absolute indices for this sample
if 'video_features' in observation and observation['video_features'].dim() > 1:
num_loaded_frames = observation['video_features'].shape[1] # (batch, seq_len, features)
frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
if frame_gap > 1:
absolute_indices = []
for j in range(num_loaded_frames):
# Step back frame_gap frames per earlier position, never before ep_start.
offset = -(num_loaded_frames - 1 - j) * frame_gap
idx = max(ep_start, frame_idx + offset)
absolute_indices.append(idx)
absolute_indices = torch.tensor(absolute_indices)
else:
start_idx = max(ep_start, frame_idx - num_loaded_frames + 1)
absolute_indices = torch.arange(start_idx, frame_idx + 1)
absolute_indices_list.append(absolute_indices)
remaining_lengths.append(ep_end - absolute_indices[0].item())
else:
absolute_indices_list.append(torch.tensor([frame_idx]))
remaining_lengths.append(ep_end - frame_idx)
observation['absolute_frame_indices'] = absolute_indices_list
observation['remaining_length'] = torch.tensor(remaining_lengths)
observation['episode_length'] = torch.tensor(episode_lengths)
# Determine number of frames from video features
if video_features.dim() >= 2:
num_frames = video_features.shape[1] if is_batch else video_features.shape[0]
else:
# Single sample case
if isinstance(frame_index, torch.Tensor):
frame_idx = frame_index.item()
else:
frame_idx = int(frame_index)
# Get episode_index
if episode_index is None:
# Search through episodes
for ep_idx in range(len(self.dataset_meta.episodes)):
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
if ep_start <= frame_idx < ep_end:
episode_index = ep_idx
break
if episode_index is None:
episode_index = 0 # Fallback
ep_idx = int(episode_index) if not isinstance(episode_index, int) else episode_index
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
episode_length = ep_end - ep_start
# Compute absolute indices
if 'video_features' in observation and observation['video_features'].dim() > 0:
num_loaded_frames = observation['video_features'].shape[0]
frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
if frame_gap > 1:
absolute_indices = []
for i in range(num_loaded_frames):
offset = -(num_loaded_frames - 1 - i) * frame_gap
idx = max(ep_start, frame_idx + offset)
absolute_indices.append(idx)
absolute_indices = torch.tensor(absolute_indices)
else:
start_idx = max(ep_start, frame_idx - num_loaded_frames + 1)
absolute_indices = torch.arange(start_idx, frame_idx + 1)
observation['absolute_frame_indices'] = absolute_indices
observation['remaining_length'] = ep_end - absolute_indices[0].item()
else:
observation['absolute_frame_indices'] = torch.tensor([frame_idx])
observation['remaining_length'] = ep_end - frame_idx
observation['episode_length'] = episode_length
num_frames = 1
# Helper-based path: one call yields the absolute frame indices, remaining length and
# episode length that are stored on the observation below.
abs_indices, remaining, ep_lengths = self._compute_episode_metadata(
frame_indices, episode_indices, num_frames, is_batch
)
observation['absolute_frame_indices'] = abs_indices
observation['remaining_length'] = remaining
observation['episode_length'] = ep_lengths
# Generate stage labels and refined progress from subtask annotations
# 6. Generate stage labels and progress targets from subtask annotations
if self.temporal_proportions is not None and self.dataset_meta is not None:
# NOTE(review): two argument lines follow for the same call (one reads video_features back
# from the observation, one passes the local); presumably two revisions -- TODO confirm.
stage_labels, progress_targets = self._generate_stage_and_progress_labels(
frame_index, episode_index, observation.get('video_features')
frame_index, episode_index, video_features
)
if stage_labels is not None:
observation['stage_labels'] = stage_labels
@@ -714,11 +647,10 @@ class SARMEncodingProcessorStep(ProcessorStep):
type=FeatureType.LANGUAGE,
shape=(self.config.text_dim,)
)
if self.config.use_joint_state:
features[PipelineFeatureType.OBSERVATION]['state_features'] = PolicyFeature(
type=FeatureType.STATE,
shape=(self.config.num_frames, self.config.state_dim)
)
features[PipelineFeatureType.OBSERVATION]['state_features'] = PolicyFeature(
type=FeatureType.STATE,
shape=(self.config.num_frames, self.config.state_dim)
)
return features