mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Add uniform sampling and transition smoothing
This commit is contained in:
@@ -286,10 +286,17 @@ def run_inference(
|
||||
state_slices = []
|
||||
|
||||
for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
|
||||
# Compute frame indices: [initial_frame (0), t-(7*gap), t-(6*gap), ..., t-gap, t]
|
||||
# The first delta is -100000 which clamps to episode start
|
||||
# Compute frame indices using symmetric bidirectional pattern:
|
||||
# [initial (0), t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
|
||||
# Boundary handling: clamp to [0, last_valid]
|
||||
deltas = model.config.observation_delta_indices
|
||||
frame_indices = [max(0, min(current_frame + delta, len(video_embeddings) - 1)) for delta in deltas]
|
||||
last_valid = len(video_embeddings) - 1
|
||||
|
||||
frame_indices = []
|
||||
for delta in deltas:
|
||||
idx = current_frame + delta
|
||||
idx = max(0, min(idx, last_valid)) # Clamp to valid range
|
||||
frame_indices.append(idx)
|
||||
|
||||
video_slice = video_embeddings[frame_indices]
|
||||
video_slices.append(video_slice)
|
||||
@@ -324,9 +331,12 @@ def run_inference(
|
||||
batch_video, batch_text, batch_states
|
||||
)
|
||||
|
||||
# Extract last frame predictions (the "current" frame)
|
||||
batch_progress = progress_preds[:, -1, 0].cpu().numpy()
|
||||
batch_stages = stage_probs[:, -1, :].cpu().numpy()
|
||||
# Extract predictions at the "current frame" position
|
||||
# With symmetric pattern [initial, t-4g, t-3g, t-2g, t-g, t, t+g, t+2g, t+3g],
|
||||
# the current frame is at position 5 (0-indexed)
|
||||
current_frame_idx = 5
|
||||
batch_progress = progress_preds[:, current_frame_idx, 0].cpu().numpy()
|
||||
batch_stages = stage_probs[:, current_frame_idx, :].cpu().numpy()
|
||||
|
||||
all_progress.extend(batch_progress)
|
||||
all_stages.extend(batch_stages)
|
||||
|
||||
@@ -17,8 +17,11 @@
|
||||
"""
|
||||
SARM Temporal Sampler for reward model training.
|
||||
|
||||
Samples frames from episodes ensuring sufficient temporal history for SARM's
|
||||
9-frame pattern (1 initial + 8 consecutive with frame_gap spacing).
|
||||
Samples frames uniformly from episodes for SARM's 9-frame symmetric pattern:
|
||||
- 1 initial frame + 4 frames before + current + 3 frames after
|
||||
|
||||
Boundary handling: clamp to first/last frame when indices go out of bounds.
|
||||
This enables truly uniform sampling across entire episodes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -31,14 +34,18 @@ import random
|
||||
|
||||
class SARMTemporalSampler(Sampler):
|
||||
"""
|
||||
Temporal sampler for SARM reward model training.
|
||||
Temporal sampler for SARM reward model training with symmetric/bidirectional sampling.
|
||||
|
||||
SARM uses 9 frames per sample:
|
||||
- Frame 0: Initial frame of the episode (always frame 0)
|
||||
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
|
||||
- Frames 1-8: Symmetric context around current frame
|
||||
Pattern: [t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
|
||||
|
||||
This sampler ensures we only sample from positions that have enough
|
||||
temporal history (at least 7 * frame_gap frames from episode start).
|
||||
Boundary handling:
|
||||
- Early frames: backward indices clamp to 0 (e.g., [0,0,0,5,35,65,95,125])
|
||||
- Late frames: forward indices clamp to last frame (e.g., [850,880,910,940,970,1000,1000,1000])
|
||||
|
||||
This enables truly uniform sampling across entire episodes.
|
||||
|
||||
Args:
|
||||
dataset_from_index: Start indices of episodes (global dataset indices)
|
||||
@@ -47,6 +54,7 @@ class SARMTemporalSampler(Sampler):
|
||||
shuffle: Whether to shuffle sampling order
|
||||
seed: Random seed for reproducibility
|
||||
samples_per_epoch: Number of samples per epoch (default: 6400)
|
||||
min_episode_length: Minimum episode length to include (default: 1)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -57,15 +65,14 @@ class SARMTemporalSampler(Sampler):
|
||||
shuffle: bool = True,
|
||||
seed: Optional[int] = None,
|
||||
samples_per_epoch: int = 6400,
|
||||
min_episode_length: int = 1,
|
||||
):
|
||||
self.dataset_from_index = np.array(dataset_from_index)
|
||||
self.dataset_to_index = np.array(dataset_to_index)
|
||||
self.frame_gap = frame_gap
|
||||
self.shuffle = shuffle
|
||||
self.samples_per_epoch = samples_per_epoch
|
||||
|
||||
# Minimum frames needed for SARM pattern: 8 consecutive frames with frame_gap spacing = 7 * frame_gap + 1
|
||||
self.min_frames_needed = 7 * frame_gap + 1
|
||||
self.min_episode_length = min_episode_length
|
||||
|
||||
if seed is not None:
|
||||
self.seed = seed
|
||||
@@ -75,18 +82,23 @@ class SARMTemporalSampler(Sampler):
|
||||
else:
|
||||
self.generator = torch.Generator()
|
||||
|
||||
# Compute valid episodes and sampling positions
|
||||
# Compute valid episodes and sampling positions (ALL frames for uniform sampling)
|
||||
self._compute_valid_positions()
|
||||
|
||||
logging.info(
|
||||
f"SARMTemporalSampler: {len(self.valid_episodes)} valid episodes, "
|
||||
f"{len(self.all_valid_positions)} valid positions, "
|
||||
f"{len(self.all_valid_positions)} positions (uniform sampling), "
|
||||
f"{self.samples_per_epoch} samples per epoch, "
|
||||
f"frame_gap={frame_gap}"
|
||||
f"frame_gap={frame_gap}, symmetric bidirectional pattern"
|
||||
)
|
||||
|
||||
def _compute_valid_positions(self):
|
||||
"""Compute valid episodes and all valid sampling positions."""
|
||||
"""Compute valid episodes and ALL sampling positions for uniform sampling.
|
||||
|
||||
With symmetric bidirectional sampling, we can sample from ANY frame:
|
||||
- Early frames: backward indices clamp to first frame
|
||||
- Late frames: forward indices clamp to last frame
|
||||
"""
|
||||
self.valid_episodes = []
|
||||
self.all_valid_positions = []
|
||||
|
||||
@@ -95,13 +107,12 @@ class SARMTemporalSampler(Sampler):
|
||||
ep_end = self.dataset_to_index[ep_idx]
|
||||
episode_length = ep_end - ep_start
|
||||
|
||||
# Episode must have enough frames for SARM pattern
|
||||
if episode_length >= self.min_frames_needed:
|
||||
# Include all episodes with at least min_episode_length frames
|
||||
if episode_length >= self.min_episode_length:
|
||||
self.valid_episodes.append((ep_idx, ep_start, ep_end))
|
||||
|
||||
# Valid positions: from min_frames_needed to episode end
|
||||
# These are global dataset indices
|
||||
for pos in range(ep_start + self.min_frames_needed - 1, ep_end):
|
||||
# Include ALL positions in the episode (truly uniform sampling)
|
||||
for pos in range(ep_start, ep_end):
|
||||
self.all_valid_positions.append(pos)
|
||||
|
||||
self.valid_episodes = np.array(self.valid_episodes)
|
||||
@@ -110,8 +121,7 @@ class SARMTemporalSampler(Sampler):
|
||||
if len(self.all_valid_positions) == 0:
|
||||
raise ValueError(
|
||||
f"No valid sampling positions found! "
|
||||
f"Episodes need at least {self.min_frames_needed} frames "
|
||||
f"(7 * frame_gap + 1 = 7 * {self.frame_gap} + 1)."
|
||||
f"Check that episodes have at least {self.min_episode_length} frames."
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
@@ -119,12 +129,15 @@ class SARMTemporalSampler(Sampler):
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
"""
|
||||
Yields global dataset indices for sampling.
|
||||
Yields global dataset indices for uniform sampling across episodes.
|
||||
|
||||
Each yielded index represents the "current frame" position.
|
||||
The dataset's observation_delta_indices then handles loading:
|
||||
- Frame 0: Episode initial frame (via large negative delta clamping)
|
||||
- Frames 1-8: Consecutive frames ending at the yielded index
|
||||
- Frames 1-8: Symmetric context around current frame (with boundary clamping)
|
||||
|
||||
For early frames: backward indices clamp to first frame (progress ~0%)
|
||||
For late frames: forward indices clamp to last frame (progress ~100%)
|
||||
"""
|
||||
if self.shuffle:
|
||||
# Randomly sample from all valid positions
|
||||
|
||||
@@ -141,24 +141,38 @@ class SARMConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
"""Load frames for SARM temporal sampling.
|
||||
"""Load frames for SARM temporal sampling with SYMMETRIC/BIDIRECTIONAL pattern.
|
||||
|
||||
Per SARM paper (Section A.4), the model uses 9 frames:
|
||||
- Frame 0: Initial frame of the episode
|
||||
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
|
||||
The model uses 9 frames with symmetric context around current frame:
|
||||
- Frame 0: Initial frame of the episode (clamped via large negative delta)
|
||||
- Frames 1-8: Symmetric context: 4 before + current + 3 after
|
||||
|
||||
The first delta uses a large negative offset (-1_000_000) that will be clamped
|
||||
to the episode start (frame 0) by the dataset loader. This ensures we always
|
||||
get the initial frame regardless of the current position in the episode.
|
||||
Pattern: [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
|
||||
|
||||
Boundary handling (done by dataset loader):
|
||||
- Early frames: backward indices clamp to 0 (first frame)
|
||||
- Late frames: forward indices clamp to episode end (last frame)
|
||||
|
||||
This enables truly uniform sampling across entire episodes.
|
||||
|
||||
Returns:
|
||||
9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0]
|
||||
9 delta indices: [-1_000_000, -4*gap, -3*gap, -2*gap, -gap, 0, gap, 2*gap, 3*gap]
|
||||
"""
|
||||
initial_frame_delta = -1_000_000
|
||||
|
||||
num_consecutive = self.num_frames - 1 # 9 - 1 = 8
|
||||
consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap)) # [-210, -180, -150, -120, -90, -60, -30, 0]
|
||||
return [initial_frame_delta] + consecutive_deltas
|
||||
# Symmetric pattern: 4 frames before, current (0), 3 frames after = 8 context frames
|
||||
symmetric_deltas = [
|
||||
-4 * self.frame_gap,
|
||||
-3 * self.frame_gap,
|
||||
-2 * self.frame_gap,
|
||||
-1 * self.frame_gap,
|
||||
0, # current frame
|
||||
1 * self.frame_gap,
|
||||
2 * self.frame_gap,
|
||||
3 * self.frame_gap,
|
||||
]
|
||||
|
||||
return [initial_frame_delta] + symmetric_deltas
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> None:
|
||||
|
||||
@@ -92,25 +92,36 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
return episode_indices
|
||||
|
||||
def _compute_absolute_indices(self, frame_idx: int, ep_start: int, num_frames: int) -> torch.Tensor:
|
||||
"""Compute absolute frame indices for a sequence.
|
||||
def _compute_absolute_indices(self, frame_idx: int, ep_start: int, ep_end: int, num_frames: int) -> torch.Tensor:
|
||||
"""Compute absolute frame indices for symmetric bidirectional pattern.
|
||||
|
||||
(per SARM paper Section A.4):
|
||||
- Frame 0: Initial frame of the episode (ep_start)
|
||||
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
|
||||
Pattern: [ep_start, t-(7*gap), t-(6*gap), ..., t-gap, t]
|
||||
|
||||
Pattern: [ep_start, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
|
||||
|
||||
Boundary handling:
|
||||
- Backward indices clamp to ep_start (first frame)
|
||||
- Forward indices clamp to ep_end - 1 (last frame)
|
||||
"""
|
||||
indices = []
|
||||
indices.append(ep_start) # First frame is the episode's initial frame
|
||||
|
||||
# Remaining frames are consecutive with frame_gap spacing
|
||||
num_consecutive = num_frames - 1
|
||||
for i in range(num_consecutive):
|
||||
offset = -(num_consecutive - 1 - i) * self.config.frame_gap
|
||||
idx = max(ep_start, frame_idx + offset)
|
||||
indices.append(ep_start) # Initial frame is always episode start
|
||||
|
||||
# Symmetric context: 4 before, current, 3 after
|
||||
num_before = 4
|
||||
num_after = 3
|
||||
last_valid_frame = ep_end - 1
|
||||
|
||||
# Frames before current (clamp to first frame)
|
||||
for i in range(num_before, 0, -1):
|
||||
idx = max(ep_start, frame_idx - i * self.config.frame_gap)
|
||||
indices.append(idx)
|
||||
|
||||
|
||||
# Current frame
|
||||
indices.append(frame_idx)
|
||||
|
||||
# Frames after current (clamp to last frame)
|
||||
for i in range(1, num_after + 1):
|
||||
idx = min(last_valid_frame, frame_idx + i * self.config.frame_gap)
|
||||
indices.append(idx)
|
||||
|
||||
return torch.tensor(indices)
|
||||
|
||||
def _compute_episode_metadata(
|
||||
@@ -134,7 +145,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
|
||||
episode_lengths.append(ep_end - ep_start)
|
||||
abs_indices = self._compute_absolute_indices(frame_idx, ep_start, num_frames)
|
||||
abs_indices = self._compute_absolute_indices(frame_idx, ep_start, ep_end, num_frames)
|
||||
absolute_indices_list.append(abs_indices)
|
||||
remaining_lengths.append(ep_end - abs_indices[0].item())
|
||||
|
||||
@@ -146,8 +157,9 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
subtask_names: list,
|
||||
subtask_start_frames: list,
|
||||
subtask_end_frames: list,
|
||||
) -> tuple[int, float]:
|
||||
"""Compute stage index and cumulative progress for a single frame.
|
||||
transition_smoothing_frames: int = 15,
|
||||
) -> tuple[int, float, dict[int, float] | None]:
|
||||
"""Compute stage index, cumulative progress, and soft stage labels for a single frame.
|
||||
|
||||
Implements SARM Paper Formula (2):
|
||||
y_t = P_{k-1} + ᾱ_k × τ_t
|
||||
@@ -157,19 +169,28 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
- P_{k-1} is cumulative prior (sum of previous subtask proportions)
|
||||
- ᾱ_k is the temporal proportion for subtask k
|
||||
|
||||
Additionally computes soft stage labels near transitions to mitigate discrete jumps
|
||||
in the stage classifier. Near stage boundaries, labels are blended between adjacent
|
||||
stages to encourage smoother predictions.
|
||||
|
||||
Args:
|
||||
current_frame: Frame index relative to episode start
|
||||
subtask_names: List of subtask names for this episode
|
||||
subtask_start_frames: List of subtask start frames
|
||||
subtask_end_frames: List of subtask end frames
|
||||
transition_smoothing_frames: Number of frames over which to smooth labels near transitions
|
||||
|
||||
Returns:
|
||||
Tuple of (stage_idx, cumulative_progress)
|
||||
Tuple of (stage_idx, cumulative_progress, soft_stage_labels)
|
||||
- stage_idx: Hard stage index (for compatibility)
|
||||
- cumulative_progress: Progress value in [0, 1]
|
||||
- soft_stage_labels: Dict mapping stage_idx -> probability, or None if not near transition
|
||||
"""
|
||||
# Get temporal proportions as list for compute_cumulative_progress
|
||||
temporal_proportions_list = [
|
||||
self.temporal_proportions.get(name, 0.0) for name in self.subtask_names
|
||||
]
|
||||
num_stages = len(self.subtask_names)
|
||||
|
||||
# Find which subtask this frame belongs to
|
||||
for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)):
|
||||
@@ -183,14 +204,34 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
# Compute cumulative progress using utility function (Paper Formula 2)
|
||||
cumulative_progress = compute_cumulative_progress_batch(
|
||||
tau, stage_idx, temporal_proportions_list
|
||||
)
|
||||
return stage_idx, cumulative_progress
|
||||
)
|
||||
|
||||
# Compute soft stage labels near transitions
|
||||
soft_stage_labels = None
|
||||
frames_from_start = current_frame - start_frame
|
||||
frames_to_end = end_frame - current_frame
|
||||
|
||||
if frames_from_start < transition_smoothing_frames and j > 0:
|
||||
# Near start of stage - blend with previous stage
|
||||
blend = frames_from_start / transition_smoothing_frames
|
||||
prev_name = subtask_names[j - 1]
|
||||
prev_stage_idx = self.subtask_names.index(prev_name) if prev_name in self.subtask_names else max(0, stage_idx - 1)
|
||||
soft_stage_labels = {prev_stage_idx: 1.0 - blend, stage_idx: blend}
|
||||
|
||||
elif frames_to_end < transition_smoothing_frames and j < len(subtask_names) - 1:
|
||||
# Near end of stage - blend with next stage
|
||||
blend = frames_to_end / transition_smoothing_frames
|
||||
next_name = subtask_names[j + 1]
|
||||
next_stage_idx = self.subtask_names.index(next_name) if next_name in self.subtask_names else min(num_stages - 1, stage_idx + 1)
|
||||
soft_stage_labels = {stage_idx: blend, next_stage_idx: 1.0 - blend}
|
||||
|
||||
return stage_idx, cumulative_progress, soft_stage_labels
|
||||
|
||||
# No matching subtask found
|
||||
if current_frame < subtask_start_frames[0]:
|
||||
return 0, 0.0
|
||||
return 0, 0.0, None
|
||||
elif current_frame > subtask_end_frames[-1]:
|
||||
return len(self.subtask_names) - 1, 1.0
|
||||
return len(self.subtask_names) - 1, 1.0, None
|
||||
else:
|
||||
# Between subtasks - use previous subtask's end state (tau = 1.0)
|
||||
for j in range(len(subtask_names) - 1):
|
||||
@@ -202,9 +243,9 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
cumulative_progress = compute_cumulative_progress_batch(
|
||||
1.0, stage_idx, temporal_proportions_list
|
||||
)
|
||||
return stage_idx, cumulative_progress
|
||||
return stage_idx, cumulative_progress, None
|
||||
|
||||
return 0, 0.0
|
||||
return 0, 0.0, None
|
||||
|
||||
def _compute_labels_for_sample(
|
||||
self,
|
||||
@@ -212,12 +253,16 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
ep_idx: int,
|
||||
seq_len: int,
|
||||
episodes_df: pd.DataFrame,
|
||||
) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]:
|
||||
"""Compute stage labels and progress targets for a single sample.
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | tuple[None, None, None]:
|
||||
"""Compute stage labels, progress targets, and soft stage labels for symmetric bidirectional pattern.
|
||||
|
||||
(per SARM paper Section A.4):
|
||||
- Frame 0: Initial frame of episode (stage at frame 0, progress at frame 0)
|
||||
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
|
||||
Pattern: [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
|
||||
|
||||
Boundary handling:
|
||||
- Before episode start: clamp to frame 0 (progress ~0%)
|
||||
- After episode end: clamp to last frame (progress ~100%)
|
||||
|
||||
Soft stage labels are computed near stage transitions to mitigate discrete jumps.
|
||||
|
||||
Args:
|
||||
frame_idx: The frame index for this sample
|
||||
@@ -226,50 +271,83 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
episodes_df: DataFrame with episode metadata
|
||||
|
||||
Returns:
|
||||
Tuple of (stage_labels, progress_targets) tensors with shapes (T,) and (T, 1),
|
||||
or (None, None) if no valid annotations
|
||||
Tuple of (stage_labels, progress_targets, soft_stage_labels):
|
||||
- stage_labels: (T,) hard stage indices
|
||||
- progress_targets: (T, 1) progress values
|
||||
- soft_stage_labels: (T, num_stages) soft probability labels, or None if no transitions nearby
|
||||
"""
|
||||
# Check if episode has valid annotations
|
||||
if ep_idx >= len(episodes_df):
|
||||
return None, None
|
||||
return None, None, None
|
||||
|
||||
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
|
||||
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
|
||||
return None, None
|
||||
return None, None, None
|
||||
|
||||
subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
|
||||
subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
ep_length = ep_end - ep_start
|
||||
last_valid_frame = ep_length - 1
|
||||
|
||||
num_stages = len(self.subtask_names)
|
||||
|
||||
# Generate labels for each frame in the sequence
|
||||
stage_labels = []
|
||||
progress_targets = []
|
||||
soft_labels_list = [] # List of soft label dicts (or None)
|
||||
has_any_soft_labels = False
|
||||
|
||||
# Symmetric pattern: initial + 4 before + current + 3 after = 9 frames
|
||||
num_before = 4
|
||||
num_after = 3
|
||||
|
||||
for i in range(seq_len):
|
||||
if i == 0:
|
||||
# Position 0: Initial frame of the episode
|
||||
current_frame = 0 # Relative to episode start
|
||||
else:
|
||||
# Positions 1-8: consecutive frames with frame_gap spacing
|
||||
num_consecutive = seq_len - 1
|
||||
offset = -(num_consecutive - i) * self.config.frame_gap
|
||||
elif i <= num_before:
|
||||
# Positions 1-4: frames before current (with clamping to first frame)
|
||||
offset = -(num_before - i + 1) * self.config.frame_gap
|
||||
current_frame = max(0, frame_idx + offset - ep_start)
|
||||
|
||||
elif i == num_before + 1:
|
||||
# Position 5: current frame
|
||||
current_frame = frame_idx - ep_start
|
||||
else:
|
||||
# Positions 6-8: frames after current (with clamping to last frame)
|
||||
offset = (i - num_before - 1) * self.config.frame_gap
|
||||
current_frame = min(last_valid_frame, frame_idx + offset - ep_start)
|
||||
|
||||
stage_idx, cumulative_progress = self._compute_stage_and_progress_for_frame(
|
||||
stage_idx, cumulative_progress, soft_stage_labels = self._compute_stage_and_progress_for_frame(
|
||||
current_frame, subtask_names, subtask_start_frames, subtask_end_frames
|
||||
)
|
||||
|
||||
stage_labels.append(stage_idx)
|
||||
progress_targets.append(cumulative_progress)
|
||||
soft_labels_list.append(soft_stage_labels)
|
||||
if soft_stage_labels is not None:
|
||||
has_any_soft_labels = True
|
||||
|
||||
stage_labels = torch.tensor(stage_labels, dtype=torch.long)
|
||||
progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1)
|
||||
|
||||
return stage_labels, progress_targets
|
||||
# Convert soft labels to tensor if any exist
|
||||
soft_stage_labels_tensor = None
|
||||
if has_any_soft_labels:
|
||||
soft_stage_labels_tensor = torch.zeros(seq_len, num_stages, dtype=torch.float32)
|
||||
for i, soft_dict in enumerate(soft_labels_list):
|
||||
if soft_dict is not None:
|
||||
for stage_idx, prob in soft_dict.items():
|
||||
soft_stage_labels_tensor[i, stage_idx] = prob
|
||||
else:
|
||||
# Use hard one-hot label
|
||||
soft_stage_labels_tensor[i, stage_labels[i]] = 1.0
|
||||
|
||||
return stage_labels, progress_targets, soft_stage_labels_tensor
|
||||
|
||||
def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features):
|
||||
"""Generate stage labels and refined progress targets from subtask annotations.
|
||||
"""Generate stage labels, progress targets, and soft stage labels from subtask annotations.
|
||||
|
||||
Args:
|
||||
frame_index: Current frame index or tensor of indices
|
||||
@@ -277,10 +355,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
video_features: Video features tensor to determine sequence length
|
||||
|
||||
Returns:
|
||||
Tuple of (stage_labels, progress_targets) or (None, None) if no annotations.
|
||||
Tuple of (stage_labels, progress_targets, soft_stage_labels) or (None, None, None) if no annotations.
|
||||
- stage_labels: (B, T) hard stage indices
|
||||
- progress_targets: (B, T, 1) progress values
|
||||
- soft_stage_labels: (B, T, num_stages) soft probability labels, or None
|
||||
"""
|
||||
if self.temporal_proportions is None or episode_index is None:
|
||||
return None, None
|
||||
return None, None, None
|
||||
|
||||
# Normalize inputs to numpy arrays
|
||||
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
|
||||
@@ -293,21 +374,48 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
seq_len = 1
|
||||
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
num_stages = len(self.subtask_names)
|
||||
|
||||
all_stage_labels = []
|
||||
all_progress_targets = []
|
||||
all_soft_stage_labels = []
|
||||
has_any_soft_labels = False
|
||||
|
||||
for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()):
|
||||
result = self._compute_labels_for_sample(int(frame_idx), int(ep_idx), seq_len, episodes_df)
|
||||
stage_labels, progress_targets, soft_labels = self._compute_labels_for_sample(
|
||||
int(frame_idx), int(ep_idx), seq_len, episodes_df
|
||||
)
|
||||
|
||||
if result[0] is None:
|
||||
if stage_labels is None:
|
||||
all_stage_labels.append(torch.zeros(seq_len, dtype=torch.long))
|
||||
all_progress_targets.append(torch.zeros(seq_len, 1, dtype=torch.float32))
|
||||
all_soft_stage_labels.append(None)
|
||||
else:
|
||||
all_stage_labels.append(result[0])
|
||||
all_progress_targets.append(result[1])
|
||||
all_stage_labels.append(stage_labels)
|
||||
all_progress_targets.append(progress_targets)
|
||||
all_soft_stage_labels.append(soft_labels)
|
||||
if soft_labels is not None:
|
||||
has_any_soft_labels = True
|
||||
|
||||
return torch.stack(all_stage_labels, dim=0), torch.stack(all_progress_targets, dim=0)
|
||||
stacked_stage_labels = torch.stack(all_stage_labels, dim=0)
|
||||
stacked_progress_targets = torch.stack(all_progress_targets, dim=0)
|
||||
|
||||
# Stack soft labels if any exist
|
||||
stacked_soft_labels = None
|
||||
if has_any_soft_labels:
|
||||
soft_labels_tensors = []
|
||||
for i, soft_labels in enumerate(all_soft_stage_labels):
|
||||
if soft_labels is not None:
|
||||
soft_labels_tensors.append(soft_labels)
|
||||
else:
|
||||
# Create one-hot from hard labels
|
||||
one_hot = torch.zeros(seq_len, num_stages, dtype=torch.float32)
|
||||
for t in range(seq_len):
|
||||
one_hot[t, all_stage_labels[i][t]] = 1.0
|
||||
soft_labels_tensors.append(one_hot)
|
||||
stacked_soft_labels = torch.stack(soft_labels_tensors, dim=0)
|
||||
|
||||
return stacked_stage_labels, stacked_progress_targets, stacked_soft_labels
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Encode images, text, and normalize states in the transition."""
|
||||
@@ -371,14 +479,16 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
observation['remaining_length'] = remaining
|
||||
observation['episode_length'] = ep_lengths
|
||||
|
||||
# Generate stage labels and progress targets from subtask annotations
|
||||
# Generate stage labels, progress targets, and soft stage labels from subtask annotations
|
||||
if self.temporal_proportions is not None and self.dataset_meta is not None:
|
||||
stage_labels, progress_targets = self._generate_stage_and_progress_labels(
|
||||
stage_labels, progress_targets, soft_stage_labels = self._generate_stage_and_progress_labels(
|
||||
frame_index, episode_index, video_features
|
||||
)
|
||||
if stage_labels is not None:
|
||||
observation['stage_labels'] = stage_labels
|
||||
observation['progress_targets'] = progress_targets
|
||||
if soft_stage_labels is not None:
|
||||
observation['soft_stage_labels'] = soft_stage_labels
|
||||
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
return new_transition
|
||||
|
||||
Reference in New Issue
Block a user