Add uniform sampling and transition smoothing

Pepijn
2025-11-28 17:15:57 +01:00
parent 6e3b972534
commit 112eb70a65
4 changed files with 237 additions and 90 deletions
+16 -6
@@ -286,10 +286,17 @@ def run_inference(
state_slices = []
for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
# Compute frame indices: [initial_frame (0), t-(7*gap), t-(6*gap), ..., t-gap, t]
# The first delta is -1_000_000, which clamps to the episode start
# Compute frame indices using symmetric bidirectional pattern:
# [initial (0), t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
# Boundary handling: clamp to [0, last_valid]
deltas = model.config.observation_delta_indices
frame_indices = [max(0, min(current_frame + delta, len(video_embeddings) - 1)) for delta in deltas]
last_valid = len(video_embeddings) - 1
frame_indices = []
for delta in deltas:
idx = current_frame + delta
idx = max(0, min(idx, last_valid)) # Clamp to valid range
frame_indices.append(idx)
video_slice = video_embeddings[frame_indices]
video_slices.append(video_slice)
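A standalone sketch of the clamped lookup above. The `frame_gap=30` value and the 200-frame episode are illustrative assumptions, not values pinned by this hunk:

```python
frame_gap = 30  # illustrative; this hunk does not fix the value
deltas = [-1_000_000] + [k * frame_gap for k in (-4, -3, -2, -1, 0, 1, 2, 3)]
last_valid = 200 - 1  # hypothetical 200-frame episode

def clamped_indices(current_frame: int) -> list[int]:
    # Clamp every delta-shifted index into [0, last_valid].
    return [max(0, min(current_frame + d, last_valid)) for d in deltas]

print(clamped_indices(10))   # [0, 0, 0, 0, 0, 10, 40, 70, 100]
print(clamped_indices(190))  # [0, 70, 100, 130, 160, 190, 199, 199, 199]
```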
@@ -324,9 +331,12 @@ def run_inference(
batch_video, batch_text, batch_states
)
# Extract last frame predictions (the "current" frame)
batch_progress = progress_preds[:, -1, 0].cpu().numpy()
batch_stages = stage_probs[:, -1, :].cpu().numpy()
# Extract predictions at the "current frame" position
# With symmetric pattern [initial, t-4g, t-3g, t-2g, t-g, t, t+g, t+2g, t+3g],
# the current frame is at position 5 (0-indexed)
current_frame_idx = 5
batch_progress = progress_preds[:, current_frame_idx, 0].cpu().numpy()
batch_stages = stage_probs[:, current_frame_idx, :].cpu().numpy()
all_progress.extend(batch_progress)
all_stages.extend(batch_stages)
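A small guard an editor might suggest (not part of this commit): derive the current-frame position from the delta list rather than hardcoding 5, so the extraction cannot drift out of sync with the config:

```python
frame_gap = 30  # illustrative value
deltas = [-1_000_000] + [k * frame_gap for k in (-4, -3, -2, -1, 0, 1, 2, 3)]
current_frame_idx = deltas.index(0)  # position of delta 0, i.e. the current frame
assert current_frame_idx == 5
```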
+35 -22
@@ -17,8 +17,11 @@
"""
SARM Temporal Sampler for reward model training.
Samples frames from episodes ensuring sufficient temporal history for SARM's
9-frame pattern (1 initial + 8 consecutive with frame_gap spacing).
Samples frames uniformly from episodes for SARM's 9-frame symmetric pattern:
- 1 initial frame + 4 frames before + current + 3 frames after
Boundary handling: clamp to first/last frame when indices go out of bounds.
This enables truly uniform sampling across entire episodes.
"""
import logging
@@ -31,14 +34,18 @@ import random
class SARMTemporalSampler(Sampler):
"""
Temporal sampler for SARM reward model training.
Temporal sampler for SARM reward model training with symmetric/bidirectional sampling.
SARM uses 9 frames per sample:
- Frame 0: Initial frame of the episode (always frame 0)
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
- Frames 1-8: Symmetric context around current frame
Pattern: [t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
This sampler ensures we only sample from positions that have enough
temporal history (at least 7 * frame_gap frames from episode start).
Boundary handling:
- Early frames: backward indices clamp to 0 (e.g., [0,0,0,5,35,65,95,125])
- Late frames: forward indices clamp to last frame (e.g., [850,880,910,940,970,1000,1000,1000])
This enables truly uniform sampling across entire episodes.
Args:
dataset_from_index: Start indices of episodes (global dataset indices)
@@ -47,6 +54,7 @@ class SARMTemporalSampler(Sampler):
shuffle: Whether to shuffle sampling order
seed: Random seed for reproducibility
samples_per_epoch: Number of samples per epoch (default: 6400)
min_episode_length: Minimum episode length to include (default: 1)
"""
def __init__(
@@ -57,15 +65,14 @@ class SARMTemporalSampler(Sampler):
shuffle: bool = True,
seed: Optional[int] = None,
samples_per_epoch: int = 6400,
min_episode_length: int = 1,
):
self.dataset_from_index = np.array(dataset_from_index)
self.dataset_to_index = np.array(dataset_to_index)
self.frame_gap = frame_gap
self.shuffle = shuffle
self.samples_per_epoch = samples_per_epoch
# Minimum frames needed for SARM pattern: 8 consecutive frames with frame_gap spacing = 7 * frame_gap + 1
self.min_frames_needed = 7 * frame_gap + 1
self.min_episode_length = min_episode_length
if seed is not None:
self.seed = seed
@@ -75,18 +82,23 @@ class SARMTemporalSampler(Sampler):
else:
self.generator = torch.Generator()
# Compute valid episodes and sampling positions
# Compute valid episodes and sampling positions (ALL frames for uniform sampling)
self._compute_valid_positions()
logging.info(
f"SARMTemporalSampler: {len(self.valid_episodes)} valid episodes, "
f"{len(self.all_valid_positions)} valid positions, "
f"{len(self.all_valid_positions)} positions (uniform sampling), "
f"{self.samples_per_epoch} samples per epoch, "
f"frame_gap={frame_gap}"
f"frame_gap={frame_gap}, symmetric bidirectional pattern"
)
def _compute_valid_positions(self):
"""Compute valid episodes and all valid sampling positions."""
"""Compute valid episodes and ALL sampling positions for uniform sampling.
With symmetric bidirectional sampling, we can sample from ANY frame:
- Early frames: backward indices clamp to first frame
- Late frames: forward indices clamp to last frame
"""
self.valid_episodes = []
self.all_valid_positions = []
@@ -95,13 +107,12 @@ class SARMTemporalSampler(Sampler):
ep_end = self.dataset_to_index[ep_idx]
episode_length = ep_end - ep_start
# Episode must have enough frames for SARM pattern
if episode_length >= self.min_frames_needed:
# Include all episodes with at least min_episode_length frames
if episode_length >= self.min_episode_length:
self.valid_episodes.append((ep_idx, ep_start, ep_end))
# Valid positions: from min_frames_needed to episode end
# These are global dataset indices
for pos in range(ep_start + self.min_frames_needed - 1, ep_end):
# Include ALL positions in the episode (truly uniform sampling)
for pos in range(ep_start, ep_end):
self.all_valid_positions.append(pos)
self.valid_episodes = np.array(self.valid_episodes)
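A minimal sketch of the position enumeration above, with made-up episode bounds; once clamping handles the boundaries, every frame becomes a sampling position:

```python
import numpy as np

dataset_from_index = np.array([0, 120, 300])  # hypothetical episode starts
dataset_to_index = np.array([120, 300, 450])  # hypothetical episode ends
min_episode_length = 1

positions = [
    pos
    for start, end in zip(dataset_from_index, dataset_to_index)
    if end - start >= min_episode_length
    for pos in range(start, end)
]
assert len(positions) == 450  # one sampling position per frame
```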
@@ -110,8 +121,7 @@ class SARMTemporalSampler(Sampler):
if len(self.all_valid_positions) == 0:
raise ValueError(
f"No valid sampling positions found! "
f"Episodes need at least {self.min_frames_needed} frames "
f"(7 * frame_gap + 1 = 7 * {self.frame_gap} + 1)."
f"Check that episodes have at least {self.min_episode_length} frames."
)
def __len__(self) -> int:
@@ -119,12 +129,15 @@ class SARMTemporalSampler(Sampler):
def __iter__(self) -> Iterator[int]:
"""
Yields global dataset indices for sampling.
Yields global dataset indices for uniform sampling across episodes.
Each yielded index represents the "current frame" position.
The dataset's observation_delta_indices then handles loading:
- Frame 0: Episode initial frame (via large negative delta clamping)
- Frames 1-8: Consecutive frames ending at the yielded index
- Frames 1-8: Symmetric context around current frame (with boundary clamping)
For early frames: backward indices clamp to first frame (progress ~0%)
For late frames: forward indices clamp to last frame (progress ~100%)
"""
if self.shuffle:
# Randomly sample from all valid positions
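The boundary examples quoted in the class docstring can be reproduced in a few lines; `frame_gap=30` and the last frame index 1000 are taken from the docstring's own numbers:

```python
frame_gap, last = 30, 1000  # values from the docstring examples

def context_indices(t: int) -> list[int]:
    # Symmetric context only; the initial frame 0 is prepended separately.
    return [max(0, min(t + k * frame_gap, last)) for k in range(-4, 4)]

assert context_indices(35) == [0, 0, 0, 5, 35, 65, 95, 125]
assert context_indices(970) == [850, 880, 910, 940, 970, 1000, 1000, 1000]
```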
+25 -11
@@ -141,24 +141,38 @@ class SARMConfig(PreTrainedConfig):
@property
def observation_delta_indices(self) -> list[int]:
"""Load frames for SARM temporal sampling.
"""Load frames for SARM temporal sampling with SYMMETRIC/BIDIRECTIONAL pattern.
Per SARM paper (Section A.4), the model uses 9 frames:
- Frame 0: Initial frame of the episode
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
The model uses 9 frames with symmetric context around current frame:
- Frame 0: Initial frame of the episode (clamped via large negative delta)
- Frames 1-8: Symmetric context: 4 before + current + 3 after
The first delta uses a large negative offset (-1_000_000) that will be clamped
to the episode start (frame 0) by the dataset loader. This ensures we always
get the initial frame regardless of the current position in the episode.
Pattern: [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling (done by dataset loader):
- Early frames: backward indices clamp to 0 (first frame)
- Late frames: forward indices clamp to episode end (last frame)
This enables truly uniform sampling across entire episodes.
Returns:
9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0]
9 delta indices: [-1_000_000, -4*gap, -3*gap, -2*gap, -gap, 0, gap, 2*gap, 3*gap]
"""
initial_frame_delta = -1_000_000
num_consecutive = self.num_frames - 1 # 9 - 1 = 8
consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap)) # [-210, -180, -150, -120, -90, -60, -30, 0]
return [initial_frame_delta] + consecutive_deltas
# Symmetric pattern: 4 frames before, current (0), 3 frames after = 8 context frames
symmetric_deltas = [
-4 * self.frame_gap,
-3 * self.frame_gap,
-2 * self.frame_gap,
-1 * self.frame_gap,
0, # current frame
1 * self.frame_gap,
2 * self.frame_gap,
3 * self.frame_gap,
]
return [initial_frame_delta] + symmetric_deltas
@property
def action_delta_indices(self) -> None:
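With frame_gap=30 (the value implied by the removed `[-210, ..., 0]` inline comment), the new property works out to:

```python
frame_gap = 30
initial_frame_delta = -1_000_000
deltas = [initial_frame_delta] + [k * frame_gap for k in (-4, -3, -2, -1, 0, 1, 2, 3)]
# [-1000000, -120, -90, -60, -30, 0, 30, 60, 90]
```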
+161 -51
@@ -92,25 +92,36 @@ class SARMEncodingProcessorStep(ProcessorStep):
return episode_indices
def _compute_absolute_indices(self, frame_idx: int, ep_start: int, num_frames: int) -> torch.Tensor:
"""Compute absolute frame indices for a sequence.
def _compute_absolute_indices(self, frame_idx: int, ep_start: int, ep_end: int, num_frames: int) -> torch.Tensor:
"""Compute absolute frame indices for symmetric bidirectional pattern.
(per SARM paper Section A.4):
- Frame 0: Initial frame of the episode (ep_start)
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
Pattern: [ep_start, t-(7*gap), t-(6*gap), ..., t-gap, t]
Pattern: [ep_start, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling:
- Backward indices clamp to ep_start (first frame)
- Forward indices clamp to ep_end - 1 (last frame)
"""
indices = []
indices.append(ep_start) # First frame is the episode's initial frame
# Remaining frames are consecutive with frame_gap spacing
num_consecutive = num_frames - 1
for i in range(num_consecutive):
offset = -(num_consecutive - 1 - i) * self.config.frame_gap
idx = max(ep_start, frame_idx + offset)
indices.append(ep_start) # Initial frame is always episode start
# Symmetric context: 4 before, current, 3 after
num_before = 4
num_after = 3
last_valid_frame = ep_end - 1
# Frames before current (clamp to first frame)
for i in range(num_before, 0, -1):
idx = max(ep_start, frame_idx - i * self.config.frame_gap)
indices.append(idx)
# Current frame
indices.append(frame_idx)
# Frames after current (clamp to last frame)
for i in range(1, num_after + 1):
idx = min(last_valid_frame, frame_idx + i * self.config.frame_gap)
indices.append(idx)
return torch.tensor(indices)
def _compute_episode_metadata(
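A self-contained version of the index computation above, handy for checking boundary behavior outside the processor; the episode bounds below are illustrative:

```python
def absolute_indices(frame_idx: int, ep_start: int, ep_end: int, frame_gap: int) -> list[int]:
    last = ep_end - 1
    idxs = [ep_start]  # initial frame is always the episode start
    idxs += [max(ep_start, frame_idx - i * frame_gap) for i in range(4, 0, -1)]
    idxs.append(frame_idx)  # current frame
    idxs += [min(last, frame_idx + i * frame_gap) for i in range(1, 4)]
    return idxs

# Episode occupying global indices [1000, 1200), current frame 1010, gap 30:
print(absolute_indices(1010, 1000, 1200, 30))
# [1000, 1000, 1000, 1000, 1000, 1010, 1040, 1070, 1100]
```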
@@ -134,7 +145,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
episode_lengths.append(ep_end - ep_start)
abs_indices = self._compute_absolute_indices(frame_idx, ep_start, num_frames)
abs_indices = self._compute_absolute_indices(frame_idx, ep_start, ep_end, num_frames)
absolute_indices_list.append(abs_indices)
remaining_lengths.append(ep_end - abs_indices[0].item())
@@ -146,8 +157,9 @@ class SARMEncodingProcessorStep(ProcessorStep):
subtask_names: list,
subtask_start_frames: list,
subtask_end_frames: list,
) -> tuple[int, float]:
"""Compute stage index and cumulative progress for a single frame.
transition_smoothing_frames: int = 15,
) -> tuple[int, float, dict[int, float] | None]:
"""Compute stage index, cumulative progress, and soft stage labels for a single frame.
Implements SARM Paper Formula (2):
y_t = P_{k-1} + ᾱ_k × τ_t
@@ -157,19 +169,28 @@ class SARMEncodingProcessorStep(ProcessorStep):
- P_{k-1} is cumulative prior (sum of previous subtask proportions)
- ᾱ_k is the temporal proportion for subtask k
Additionally computes soft stage labels near transitions to mitigate discrete jumps
in the stage classifier. Near stage boundaries, labels are blended between adjacent
stages to encourage smoother predictions.
Args:
current_frame: Frame index relative to episode start
subtask_names: List of subtask names for this episode
subtask_start_frames: List of subtask start frames
subtask_end_frames: List of subtask end frames
transition_smoothing_frames: Number of frames over which to smooth labels near transitions
Returns:
Tuple of (stage_idx, cumulative_progress)
Tuple of (stage_idx, cumulative_progress, soft_stage_labels)
- stage_idx: Hard stage index (for compatibility)
- cumulative_progress: Progress value in [0, 1]
- soft_stage_labels: Dict mapping stage_idx -> probability, or None if not near transition
"""
# Get temporal proportions as list for compute_cumulative_progress
temporal_proportions_list = [
self.temporal_proportions.get(name, 0.0) for name in self.subtask_names
]
num_stages = len(self.subtask_names)
# Find which subtask this frame belongs to
for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)):
@@ -183,14 +204,34 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Compute cumulative progress using utility function (Paper Formula 2)
cumulative_progress = compute_cumulative_progress_batch(
tau, stage_idx, temporal_proportions_list
)
return stage_idx, cumulative_progress
)
# Compute soft stage labels near transitions
soft_stage_labels = None
frames_from_start = current_frame - start_frame
frames_to_end = end_frame - current_frame
if frames_from_start < transition_smoothing_frames and j > 0:
# Near start of stage - crossfade with previous stage (50/50 at the boundary frame)
blend = frames_from_start / transition_smoothing_frames
prev_name = subtask_names[j - 1]
prev_stage_idx = self.subtask_names.index(prev_name) if prev_name in self.subtask_names else max(0, stage_idx - 1)
soft_stage_labels = {prev_stage_idx: 0.5 * (1.0 - blend), stage_idx: 0.5 * (1.0 + blend)}
elif frames_to_end < transition_smoothing_frames and j < len(subtask_names) - 1:
# Near end of stage - crossfade with next stage (50/50 at the boundary frame)
blend = frames_to_end / transition_smoothing_frames
next_name = subtask_names[j + 1]
next_stage_idx = self.subtask_names.index(next_name) if next_name in self.subtask_names else min(num_stages - 1, stage_idx + 1)
soft_stage_labels = {stage_idx: 0.5 * (1.0 + blend), next_stage_idx: 0.5 * (1.0 - blend)}
return stage_idx, cumulative_progress, soft_stage_labels
# No matching subtask found
if current_frame < subtask_start_frames[0]:
return 0, 0.0
return 0, 0.0, None
elif current_frame > subtask_end_frames[-1]:
return len(self.subtask_names) - 1, 1.0
return len(self.subtask_names) - 1, 1.0, None
else:
# Between subtasks - use previous subtask's end state (tau = 1.0)
for j in range(len(subtask_names) - 1):
@@ -202,9 +243,9 @@ class SARMEncodingProcessorStep(ProcessorStep):
cumulative_progress = compute_cumulative_progress_batch(
1.0, stage_idx, temporal_proportions_list
)
return stage_idx, cumulative_progress
return stage_idx, cumulative_progress, None
return 0, 0.0
return 0, 0.0, None
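How the crossfade behaves near a stage boundary, assuming the default transition_smoothing_frames=15 from the signature above; anchoring at 50/50 keeps the soft labels continuous where adjacent stages meet:

```python
w = 15  # transition_smoothing_frames (default from the signature above)

def soft_label_near_start(frames_from_start: int) -> tuple[float, float]:
    # Returns (previous-stage weight, current-stage weight).
    blend = frames_from_start / w
    return 0.5 * (1.0 - blend), 0.5 * (1.0 + blend)

print(soft_label_near_start(0))   # (0.5, 0.5) at the boundary frame
print(soft_label_near_start(14))  # (~0.03, ~0.97) at the window's edge
```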
def _compute_labels_for_sample(
self,
@@ -212,12 +253,16 @@ class SARMEncodingProcessorStep(ProcessorStep):
ep_idx: int,
seq_len: int,
episodes_df: pd.DataFrame,
) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]:
"""Compute stage labels and progress targets for a single sample.
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | tuple[None, None, None]:
"""Compute stage labels, progress targets, and soft stage labels for symmetric bidirectional pattern.
(per SARM paper Section A.4):
- Frame 0: Initial frame of episode (stage at frame 0, progress at frame 0)
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
Pattern: [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling:
- Before episode start: clamp to frame 0 (progress ~0%)
- After episode end: clamp to last frame (progress ~100%)
Soft stage labels are computed near stage transitions to mitigate discrete jumps.
Args:
frame_idx: The frame index for this sample
@@ -226,50 +271,83 @@ class SARMEncodingProcessorStep(ProcessorStep):
episodes_df: DataFrame with episode metadata
Returns:
Tuple of (stage_labels, progress_targets) tensors with shapes (T,) and (T, 1),
or (None, None) if no valid annotations
Tuple of (stage_labels, progress_targets, soft_stage_labels):
- stage_labels: (T,) hard stage indices
- progress_targets: (T, 1) progress values
- soft_stage_labels: (T, num_stages) soft probability labels, or None if no transitions nearby
"""
# Check if episode has valid annotations
if ep_idx >= len(episodes_df):
return None, None
return None, None, None
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
return None, None
return None, None, None
subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
ep_length = ep_end - ep_start
last_valid_frame = ep_length - 1
num_stages = len(self.subtask_names)
# Generate labels for each frame in the sequence
stage_labels = []
progress_targets = []
soft_labels_list = [] # List of soft label dicts (or None)
has_any_soft_labels = False
# Symmetric pattern: initial + 4 before + current + 3 after = 9 frames
num_before = 4
num_after = 3
for i in range(seq_len):
if i == 0:
# Position 0: Initial frame of the episode
current_frame = 0 # Relative to episode start
else:
# Positions 1-8: consecutive frames with frame_gap spacing
num_consecutive = seq_len - 1
offset = -(num_consecutive - i) * self.config.frame_gap
elif i <= num_before:
# Positions 1-4: frames before current (with clamping to first frame)
offset = -(num_before - i + 1) * self.config.frame_gap
current_frame = max(0, frame_idx + offset - ep_start)
elif i == num_before + 1:
# Position 5: current frame
current_frame = frame_idx - ep_start
else:
# Positions 6-8: frames after current (with clamping to last frame)
offset = (i - num_before - 1) * self.config.frame_gap
current_frame = min(last_valid_frame, frame_idx + offset - ep_start)
stage_idx, cumulative_progress = self._compute_stage_and_progress_for_frame(
stage_idx, cumulative_progress, soft_stage_labels = self._compute_stage_and_progress_for_frame(
current_frame, subtask_names, subtask_start_frames, subtask_end_frames
)
stage_labels.append(stage_idx)
progress_targets.append(cumulative_progress)
soft_labels_list.append(soft_stage_labels)
if soft_stage_labels is not None:
has_any_soft_labels = True
stage_labels = torch.tensor(stage_labels, dtype=torch.long)
progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1)
return stage_labels, progress_targets
# Convert soft labels to tensor if any exist
soft_stage_labels_tensor = None
if has_any_soft_labels:
soft_stage_labels_tensor = torch.zeros(seq_len, num_stages, dtype=torch.float32)
for i, soft_dict in enumerate(soft_labels_list):
if soft_dict is not None:
for stage_idx, prob in soft_dict.items():
soft_stage_labels_tensor[i, stage_idx] = prob
else:
# Use hard one-hot label
soft_stage_labels_tensor[i, stage_labels[i]] = 1.0
return stage_labels, progress_targets, soft_stage_labels_tensor
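The dict-to-tensor conversion above, reduced to a runnable sketch with made-up labels; rows without a soft dict fall back to one-hot:

```python
import torch

seq_len, num_stages = 9, 4  # illustrative sizes
stage_labels = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
soft_dicts: list[dict[int, float] | None] = [None] * seq_len
soft_dicts[3] = {0: 0.5, 1: 0.5}  # a frame sitting at a stage boundary

soft = torch.zeros(seq_len, num_stages)
for i, d in enumerate(soft_dicts):
    if d is None:
        soft[i, stage_labels[i]] = 1.0  # hard one-hot row
    else:
        for k, p in d.items():  # blended row
            soft[i, k] = p
```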
def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features):
"""Generate stage labels and refined progress targets from subtask annotations.
"""Generate stage labels, progress targets, and soft stage labels from subtask annotations.
Args:
frame_index: Current frame index or tensor of indices
@@ -277,10 +355,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
video_features: Video features tensor to determine sequence length
Returns:
Tuple of (stage_labels, progress_targets) or (None, None) if no annotations.
Tuple of (stage_labels, progress_targets, soft_stage_labels) or (None, None, None) if no annotations.
- stage_labels: (B, T) hard stage indices
- progress_targets: (B, T, 1) progress values
- soft_stage_labels: (B, T, num_stages) soft probability labels, or None
"""
if self.temporal_proportions is None or episode_index is None:
return None, None
return None, None, None
# Normalize inputs to numpy arrays
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
@@ -293,21 +374,48 @@ class SARMEncodingProcessorStep(ProcessorStep):
seq_len = 1
episodes_df = self.dataset_meta.episodes.to_pandas()
num_stages = len(self.subtask_names)
all_stage_labels = []
all_progress_targets = []
all_soft_stage_labels = []
has_any_soft_labels = False
for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()):
result = self._compute_labels_for_sample(int(frame_idx), int(ep_idx), seq_len, episodes_df)
stage_labels, progress_targets, soft_labels = self._compute_labels_for_sample(
int(frame_idx), int(ep_idx), seq_len, episodes_df
)
if result[0] is None:
if stage_labels is None:
all_stage_labels.append(torch.zeros(seq_len, dtype=torch.long))
all_progress_targets.append(torch.zeros(seq_len, 1, dtype=torch.float32))
all_soft_stage_labels.append(None)
else:
all_stage_labels.append(result[0])
all_progress_targets.append(result[1])
all_stage_labels.append(stage_labels)
all_progress_targets.append(progress_targets)
all_soft_stage_labels.append(soft_labels)
if soft_labels is not None:
has_any_soft_labels = True
return torch.stack(all_stage_labels, dim=0), torch.stack(all_progress_targets, dim=0)
stacked_stage_labels = torch.stack(all_stage_labels, dim=0)
stacked_progress_targets = torch.stack(all_progress_targets, dim=0)
# Stack soft labels if any exist
stacked_soft_labels = None
if has_any_soft_labels:
soft_labels_tensors = []
for i, soft_labels in enumerate(all_soft_stage_labels):
if soft_labels is not None:
soft_labels_tensors.append(soft_labels)
else:
# Create one-hot from hard labels
one_hot = torch.zeros(seq_len, num_stages, dtype=torch.float32)
for t in range(seq_len):
one_hot[t, all_stage_labels[i][t]] = 1.0
soft_labels_tensors.append(one_hot)
stacked_soft_labels = torch.stack(soft_labels_tensors, dim=0)
return stacked_stage_labels, stacked_progress_targets, stacked_soft_labels
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Encode images, text, and normalize states in the transition."""
@@ -371,14 +479,16 @@ class SARMEncodingProcessorStep(ProcessorStep):
observation['remaining_length'] = remaining
observation['episode_length'] = ep_lengths
# Generate stage labels and progress targets from subtask annotations
# Generate stage labels, progress targets, and soft stage labels from subtask annotations
if self.temporal_proportions is not None and self.dataset_meta is not None:
stage_labels, progress_targets = self._generate_stage_and_progress_labels(
stage_labels, progress_targets, soft_stage_labels = self._generate_stage_and_progress_labels(
frame_index, episode_index, video_features
)
if stage_labels is not None:
observation['stage_labels'] = stage_labels
observation['progress_targets'] = progress_targets
if soft_stage_labels is not None:
observation['soft_stage_labels'] = soft_stage_labels
new_transition[TransitionKey.OBSERVATION] = observation
return new_transition
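The diff adds `soft_stage_labels` to the observation but does not show the consuming loss. A plausible trainer-side sketch (an assumption, not part of this commit) is cross-entropy against the soft distribution:

```python
import torch
import torch.nn.functional as F

def soft_stage_loss(stage_logits: torch.Tensor, soft_labels: torch.Tensor) -> torch.Tensor:
    # stage_logits: (B, T, num_stages); soft_labels: (B, T, num_stages), rows sum to 1.
    log_probs = F.log_softmax(stage_logits, dim=-1)
    return -(soft_labels * log_probs).sum(dim=-1).mean()
```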