diff --git a/scripts/visualize_sarm_predictions.py b/scripts/visualize_sarm_predictions.py index 352438674..4119398d4 100644 --- a/scripts/visualize_sarm_predictions.py +++ b/scripts/visualize_sarm_predictions.py @@ -286,10 +286,17 @@ def run_inference( state_slices = [] for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"): - # Compute frame indices: [initial_frame (0), t-(7*gap), t-(6*gap), ..., t-gap, t] - # The first delta is -100000 which clamps to episode start + # Compute frame indices using symmetric bidirectional pattern: + # [initial (0), t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap] + # Boundary handling: clamp to [0, last_valid] deltas = model.config.observation_delta_indices - frame_indices = [max(0, min(current_frame + delta, len(video_embeddings) - 1)) for delta in deltas] + last_valid = len(video_embeddings) - 1 + + frame_indices = [] + for delta in deltas: + idx = current_frame + delta + idx = max(0, min(idx, last_valid)) # Clamp to valid range + frame_indices.append(idx) video_slice = video_embeddings[frame_indices] video_slices.append(video_slice) @@ -324,9 +331,12 @@ def run_inference( batch_video, batch_text, batch_states ) - # Extract last frame predictions (the "current" frame) - batch_progress = progress_preds[:, -1, 0].cpu().numpy() - batch_stages = stage_probs[:, -1, :].cpu().numpy() + # Extract predictions at the "current frame" position + # With symmetric pattern [initial, t-4g, t-3g, t-2g, t-g, t, t+g, t+2g, t+3g], + # the current frame is at position 5 (0-indexed) + current_frame_idx = 5 + batch_progress = progress_preds[:, current_frame_idx, 0].cpu().numpy() + batch_stages = stage_probs[:, current_frame_idx, :].cpu().numpy() all_progress.extend(batch_progress) all_stages.extend(batch_stages) diff --git a/src/lerobot/datasets/temporal_sampler.py b/src/lerobot/datasets/temporal_sampler.py index 895e96295..c4d9e5030 100644 --- a/src/lerobot/datasets/temporal_sampler.py +++ b/src/lerobot/datasets/temporal_sampler.py @@ -17,8 +17,11 @@ """ SARM Temporal Sampler for reward model training. -Samples frames from episodes ensuring sufficient temporal history for SARM's -9-frame pattern (1 initial + 8 consecutive with frame_gap spacing). +Samples frames uniformly from episodes for SARM's 9-frame symmetric pattern: +- 1 initial frame + 4 frames before + current + 3 frames after + +Boundary handling: clamp to first/last frame when indices go out of bounds. +This enables truly uniform sampling across entire episodes. """ import logging @@ -31,14 +34,18 @@ import random class SARMTemporalSampler(Sampler): """ - Temporal sampler for SARM reward model training. + Temporal sampler for SARM reward model training with symmetric/bidirectional sampling. SARM uses 9 frames per sample: - Frame 0: Initial frame of the episode (always frame 0) - - Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame + - Frames 1-8: Symmetric context around current frame + Pattern: [t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap] - This sampler ensures we only sample from positions that have enough - temporal history (at least 7 * frame_gap frames from episode start). + Boundary handling: + - Early frames: backward indices clamp to 0 (e.g., [0,0,0,5,35,65,95,125]) + - Late frames: forward indices clamp to last frame (e.g., [850,880,910,940,970,1000,1000,1000]) + + This enables truly uniform sampling across entire episodes. Args: dataset_from_index: Start indices of episodes (global dataset indices) @@ -47,6 +54,7 @@ class SARMTemporalSampler(Sampler): shuffle: Whether to shuffle sampling order seed: Random seed for reproducibility samples_per_epoch: Number of samples per epoch (default: 6400) + min_episode_length: Minimum episode length to include (default: 1) """ def __init__( @@ -57,15 +65,14 @@ class SARMTemporalSampler(Sampler): shuffle: bool = True, seed: Optional[int] = None, samples_per_epoch: int = 6400, + min_episode_length: int = 1, ): self.dataset_from_index = np.array(dataset_from_index) self.dataset_to_index = np.array(dataset_to_index) self.frame_gap = frame_gap self.shuffle = shuffle self.samples_per_epoch = samples_per_epoch - - # Minimum frames needed for SARM pattern: 8 consecutive frames with frame_gap spacing = 7 * frame_gap + 1 - self.min_frames_needed = 7 * frame_gap + 1 + self.min_episode_length = min_episode_length if seed is not None: self.seed = seed @@ -75,18 +82,23 @@ class SARMTemporalSampler(Sampler): else: self.generator = torch.Generator() - # Compute valid episodes and sampling positions + # Compute valid episodes and sampling positions (ALL frames for uniform sampling) self._compute_valid_positions() logging.info( f"SARMTemporalSampler: {len(self.valid_episodes)} valid episodes, " - f"{len(self.all_valid_positions)} valid positions, " + f"{len(self.all_valid_positions)} positions (uniform sampling), " f"{self.samples_per_epoch} samples per epoch, " - f"frame_gap={frame_gap}" + f"frame_gap={frame_gap}, symmetric bidirectional pattern" ) def _compute_valid_positions(self): - """Compute valid episodes and all valid sampling positions.""" + """Compute valid episodes and ALL sampling positions for uniform sampling. + + With symmetric bidirectional sampling, we can sample from ANY frame: + - Early frames: backward indices clamp to first frame + - Late frames: forward indices clamp to last frame + """ self.valid_episodes = [] self.all_valid_positions = [] @@ -95,13 +107,12 @@ class SARMTemporalSampler(Sampler): ep_end = self.dataset_to_index[ep_idx] episode_length = ep_end - ep_start - # Episode must have enough frames for SARM pattern - if episode_length >= self.min_frames_needed: + # Include all episodes with at least min_episode_length frames + if episode_length >= self.min_episode_length: self.valid_episodes.append((ep_idx, ep_start, ep_end)) - # Valid positions: from min_frames_needed to episode end - # These are global dataset indices - for pos in range(ep_start + self.min_frames_needed - 1, ep_end): + # Include ALL positions in the episode (truly uniform sampling) + for pos in range(ep_start, ep_end): self.all_valid_positions.append(pos) self.valid_episodes = np.array(self.valid_episodes) @@ -110,8 +121,7 @@ class SARMTemporalSampler(Sampler): if len(self.all_valid_positions) == 0: raise ValueError( f"No valid sampling positions found! " - f"Episodes need at least {self.min_frames_needed} frames " - f"(7 * frame_gap + 1 = 7 * {self.frame_gap} + 1)." + f"Check that episodes have at least {self.min_episode_length} frames." ) def __len__(self) -> int: @@ -119,12 +129,15 @@ class SARMTemporalSampler(Sampler): def __iter__(self) -> Iterator[int]: """ - Yields global dataset indices for sampling. + Yields global dataset indices for uniform sampling across episodes. Each yielded index represents the "current frame" position. The dataset's observation_delta_indices then handles loading: - Frame 0: Episode initial frame (via large negative delta clamping) - - Frames 1-8: Consecutive frames ending at the yielded index + - Frames 1-8: Symmetric context around current frame (with boundary clamping) + + For early frames: backward indices clamp to first frame (progress ~0%) + For late frames: forward indices clamp to last frame (progress ~100%) """ if self.shuffle: # Randomly sample from all valid positions diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py index eef461288..bbc10c6a7 100644 --- a/src/lerobot/policies/sarm/configuration_sarm.py +++ b/src/lerobot/policies/sarm/configuration_sarm.py @@ -141,24 +141,38 @@ class SARMConfig(PreTrainedConfig): @property def observation_delta_indices(self) -> list[int]: - """Load frames for SARM temporal sampling. + """Load frames for SARM temporal sampling with SYMMETRIC/BIDIRECTIONAL pattern. - Per SARM paper (Section A.4), the model uses 9 frames: - - Frame 0: Initial frame of the episode - - Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame + The model uses 9 frames with symmetric context around current frame: + - Frame 0: Initial frame of the episode (clamped via large negative delta) + - Frames 1-8: Symmetric context: 4 before + current + 3 after - The first delta uses a large negative offset (-1_000_000) that will be clamped - to the episode start (frame 0) by the dataset loader. This ensures we always - get the initial frame regardless of the current position in the episode. + Pattern: [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap] + + Boundary handling (done by dataset loader): + - Early frames: backward indices clamp to 0 (first frame) + - Late frames: forward indices clamp to episode end (last frame) + + This enables truly uniform sampling across entire episodes. Returns: - 9 delta indices: [-1_000_000, -(7*gap), -(6*gap), ..., -gap, 0] + 9 delta indices: [-1_000_000, -4*gap, -3*gap, -2*gap, -gap, 0, gap, 2*gap, 3*gap] """ initial_frame_delta = -1_000_000 - num_consecutive = self.num_frames - 1 # 9 - 1 = 8 - consecutive_deltas = list(range(-self.frame_gap * (num_consecutive - 1), 1, self.frame_gap)) # [-210, -180, -150, -120, -90, -60, -30, 0] - return [initial_frame_delta] + consecutive_deltas + # Symmetric pattern: 4 frames before, current (0), 3 frames after = 8 context frames + symmetric_deltas = [ + -4 * self.frame_gap, + -3 * self.frame_gap, + -2 * self.frame_gap, + -1 * self.frame_gap, + 0, # current frame + 1 * self.frame_gap, + 2 * self.frame_gap, + 3 * self.frame_gap, + ] + + return [initial_frame_delta] + symmetric_deltas @property def action_delta_indices(self) -> None: diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py index 4badb1b86..0a8b03943 100644 --- a/src/lerobot/policies/sarm/processor_sarm.py +++ b/src/lerobot/policies/sarm/processor_sarm.py @@ -92,25 +92,36 @@ class SARMEncodingProcessorStep(ProcessorStep): return episode_indices - def _compute_absolute_indices(self, frame_idx: int, ep_start: int, num_frames: int) -> torch.Tensor: - """Compute absolute frame indices for a sequence. + def _compute_absolute_indices(self, frame_idx: int, ep_start: int, ep_end: int, num_frames: int) -> torch.Tensor: + """Compute absolute frame indices for symmetric bidirectional pattern. - (per SARM paper Section A.4): - - Frame 0: Initial frame of the episode (ep_start) - - Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame - Pattern: [ep_start, t-(7*gap), t-(6*gap), ..., t-gap, t] - + Pattern: [ep_start, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap] + + Boundary handling: + - Backward indices clamp to ep_start (first frame) + - Forward indices clamp to ep_end - 1 (last frame) """ indices = [] - indices.append(ep_start) # First frame is the episode's initial frame - - # Remaining frames are consecutive with frame_gap spacing - num_consecutive = num_frames - 1 - for i in range(num_consecutive): - offset = -(num_consecutive - 1 - i) * self.config.frame_gap - idx = max(ep_start, frame_idx + offset) + indices.append(ep_start) # Initial frame is always episode start + + # Symmetric context: 4 before, current, 3 after + num_before = 4 + num_after = 3 + last_valid_frame = ep_end - 1 + + # Frames before current (clamp to first frame) + for i in range(num_before, 0, -1): + idx = max(ep_start, frame_idx - i * self.config.frame_gap) indices.append(idx) - + + # Current frame + indices.append(frame_idx) + + # Frames after current (clamp to last frame) + for i in range(1, num_after + 1): + idx = min(last_valid_frame, frame_idx + i * self.config.frame_gap) + indices.append(idx) + return torch.tensor(indices) def _compute_episode_metadata( @@ -134,7 +145,7 @@ class SARMEncodingProcessorStep(ProcessorStep): ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] episode_lengths.append(ep_end - ep_start) - abs_indices = self._compute_absolute_indices(frame_idx, ep_start, num_frames) + abs_indices = self._compute_absolute_indices(frame_idx, ep_start, ep_end, num_frames) absolute_indices_list.append(abs_indices) remaining_lengths.append(ep_end - abs_indices[0].item()) @@ -146,8 +157,9 @@ class SARMEncodingProcessorStep(ProcessorStep): subtask_names: list, subtask_start_frames: list, subtask_end_frames: list, - ) -> tuple[int, float]: - """Compute stage index and cumulative progress for a single frame. + transition_smoothing_frames: int = 15, + ) -> tuple[int, float, dict[int, float] | None]: + """Compute stage index, cumulative progress, and soft stage labels for a single frame. Implements SARM Paper Formula (2): y_t = P_{k-1} + ᾱ_k × τ_t @@ -157,19 +169,28 @@ class SARMEncodingProcessorStep(ProcessorStep): - P_{k-1} is cumulative prior (sum of previous subtask proportions) - ᾱ_k is the temporal proportion for subtask k + Additionally computes soft stage labels near transitions to mitigate discrete jumps + in the stage classifier. Near stage boundaries, labels are blended between adjacent + stages to encourage smoother predictions. + Args: current_frame: Frame index relative to episode start subtask_names: List of subtask names for this episode subtask_start_frames: List of subtask start frames subtask_end_frames: List of subtask end frames + transition_smoothing_frames: Number of frames over which to smooth labels near transitions Returns: - Tuple of (stage_idx, cumulative_progress) + Tuple of (stage_idx, cumulative_progress, soft_stage_labels) + - stage_idx: Hard stage index (for compatibility) + - cumulative_progress: Progress value in [0, 1] + - soft_stage_labels: Dict mapping stage_idx -> probability, or None if not near transition """ # Get temporal proportions as list for compute_cumulative_progress temporal_proportions_list = [ self.temporal_proportions.get(name, 0.0) for name in self.subtask_names ] + num_stages = len(self.subtask_names) # Find which subtask this frame belongs to for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)): @@ -183,14 +204,34 @@ class SARMEncodingProcessorStep(ProcessorStep): # Compute cumulative progress using utility function (Paper Formula 2) cumulative_progress = compute_cumulative_progress_batch( tau, stage_idx, temporal_proportions_list - ) - return stage_idx, cumulative_progress + ) + + # Compute soft stage labels near transitions + soft_stage_labels = None + frames_from_start = current_frame - start_frame + frames_to_end = end_frame - current_frame + + if frames_from_start < transition_smoothing_frames and j > 0: + # Near start of stage - blend with previous stage + blend = frames_from_start / transition_smoothing_frames + prev_name = subtask_names[j - 1] + prev_stage_idx = self.subtask_names.index(prev_name) if prev_name in self.subtask_names else max(0, stage_idx - 1) + soft_stage_labels = {prev_stage_idx: 1.0 - blend, stage_idx: blend} + + elif frames_to_end < transition_smoothing_frames and j < len(subtask_names) - 1: + # Near end of stage - blend with next stage + blend = frames_to_end / transition_smoothing_frames + next_name = subtask_names[j + 1] + next_stage_idx = self.subtask_names.index(next_name) if next_name in self.subtask_names else min(num_stages - 1, stage_idx + 1) + soft_stage_labels = {stage_idx: blend, next_stage_idx: 1.0 - blend} + + return stage_idx, cumulative_progress, soft_stage_labels # No matching subtask found if current_frame < subtask_start_frames[0]: - return 0, 0.0 + return 0, 0.0, None elif current_frame > subtask_end_frames[-1]: - return len(self.subtask_names) - 1, 1.0 + return len(self.subtask_names) - 1, 1.0, None else: # Between subtasks - use previous subtask's end state (tau = 1.0) for j in range(len(subtask_names) - 1): @@ -202,9 +243,9 @@ class SARMEncodingProcessorStep(ProcessorStep): cumulative_progress = compute_cumulative_progress_batch( 1.0, stage_idx, temporal_proportions_list ) - return stage_idx, cumulative_progress + return stage_idx, cumulative_progress, None - return 0, 0.0 + return 0, 0.0, None def _compute_labels_for_sample( self, @@ -212,12 +253,16 @@ class SARMEncodingProcessorStep(ProcessorStep): ep_idx: int, seq_len: int, episodes_df: pd.DataFrame, - ) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]: - """Compute stage labels and progress targets for a single sample. + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | tuple[None, None, None]: + """Compute stage labels, progress targets, and soft stage labels for symmetric bidirectional pattern. - (per SARM paper Section A.4): - - Frame 0: Initial frame of episode (stage at frame 0, progress at frame 0) - - Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame + Pattern: [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap] + + Boundary handling: + - Before episode start: clamp to frame 0 (progress ~0%) + - After episode end: clamp to last frame (progress ~100%) + + Soft stage labels are computed near stage transitions to mitigate discrete jumps. Args: frame_idx: The frame index for this sample @@ -226,50 +271,83 @@ class SARMEncodingProcessorStep(ProcessorStep): episodes_df: DataFrame with episode metadata Returns: - Tuple of (stage_labels, progress_targets) tensors with shapes (T,) and (T, 1), - or (None, None) if no valid annotations + Tuple of (stage_labels, progress_targets, soft_stage_labels): + - stage_labels: (T,) hard stage indices + - progress_targets: (T, 1) progress values + - soft_stage_labels: (T, num_stages) soft probability labels, or None if no transitions nearby """ # Check if episode has valid annotations if ep_idx >= len(episodes_df): - return None, None + return None, None, None subtask_names = episodes_df.loc[ep_idx, 'subtask_names'] if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)): - return None, None + return None, None, None subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames'] subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames'] ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] + ep_length = ep_end - ep_start + last_valid_frame = ep_length - 1 + + num_stages = len(self.subtask_names) # Generate labels for each frame in the sequence stage_labels = [] progress_targets = [] + soft_labels_list = [] # List of soft label dicts (or None) + has_any_soft_labels = False + + # Symmetric pattern: initial + 4 before + current + 3 after = 9 frames + num_before = 4 + num_after = 3 for i in range(seq_len): if i == 0: # Position 0: Initial frame of the episode current_frame = 0 # Relative to episode start - else: - # Positions 1-8: consecutive frames with frame_gap spacing - num_consecutive = seq_len - 1 - offset = -(num_consecutive - i) * self.config.frame_gap + elif i <= num_before: + # Positions 1-4: frames before current (with clamping to first frame) + offset = -(num_before - i + 1) * self.config.frame_gap current_frame = max(0, frame_idx + offset - ep_start) - + elif i == num_before + 1: + # Position 5: current frame + current_frame = frame_idx - ep_start + else: + # Positions 6-8: frames after current (with clamping to last frame) + offset = (i - num_before - 1) * self.config.frame_gap + current_frame = min(last_valid_frame, frame_idx + offset - ep_start) - stage_idx, cumulative_progress = self._compute_stage_and_progress_for_frame( + stage_idx, cumulative_progress, soft_stage_labels = self._compute_stage_and_progress_for_frame( current_frame, subtask_names, subtask_start_frames, subtask_end_frames ) stage_labels.append(stage_idx) progress_targets.append(cumulative_progress) + soft_labels_list.append(soft_stage_labels) + if soft_stage_labels is not None: + has_any_soft_labels = True stage_labels = torch.tensor(stage_labels, dtype=torch.long) progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1) - return stage_labels, progress_targets + # Convert soft labels to tensor if any exist + soft_stage_labels_tensor = None + if has_any_soft_labels: + soft_stage_labels_tensor = torch.zeros(seq_len, num_stages, dtype=torch.float32) + for i, soft_dict in enumerate(soft_labels_list): + if soft_dict is not None: + for stage_idx, prob in soft_dict.items(): + soft_stage_labels_tensor[i, stage_idx] = prob + else: + # Use hard one-hot label + soft_stage_labels_tensor[i, stage_labels[i]] = 1.0 + + return stage_labels, progress_targets, soft_stage_labels_tensor def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features): - """Generate stage labels and refined progress targets from subtask annotations. + """Generate stage labels, progress targets, and soft stage labels from subtask annotations. Args: frame_index: Current frame index or tensor of indices @@ -277,10 +355,13 @@ class SARMEncodingProcessorStep(ProcessorStep): video_features: Video features tensor to determine sequence length Returns: - Tuple of (stage_labels, progress_targets) or (None, None) if no annotations. + Tuple of (stage_labels, progress_targets, soft_stage_labels) or (None, None, None) if no annotations. + - stage_labels: (B, T) hard stage indices + - progress_targets: (B, T, 1) progress values + - soft_stage_labels: (B, T, num_stages) soft probability labels, or None """ if self.temporal_proportions is None or episode_index is None: - return None, None + return None, None, None # Normalize inputs to numpy arrays frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index))) @@ -293,21 +374,48 @@ class SARMEncodingProcessorStep(ProcessorStep): seq_len = 1 episodes_df = self.dataset_meta.episodes.to_pandas() + num_stages = len(self.subtask_names) all_stage_labels = [] all_progress_targets = [] + all_soft_stage_labels = [] + has_any_soft_labels = False for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()): - result = self._compute_labels_for_sample(int(frame_idx), int(ep_idx), seq_len, episodes_df) + stage_labels, progress_targets, soft_labels = self._compute_labels_for_sample( + int(frame_idx), int(ep_idx), seq_len, episodes_df + ) - if result[0] is None: + if stage_labels is None: all_stage_labels.append(torch.zeros(seq_len, dtype=torch.long)) all_progress_targets.append(torch.zeros(seq_len, 1, dtype=torch.float32)) + all_soft_stage_labels.append(None) else: - all_stage_labels.append(result[0]) - all_progress_targets.append(result[1]) + all_stage_labels.append(stage_labels) + all_progress_targets.append(progress_targets) + all_soft_stage_labels.append(soft_labels) + if soft_labels is not None: + has_any_soft_labels = True - return torch.stack(all_stage_labels, dim=0), torch.stack(all_progress_targets, dim=0) + stacked_stage_labels = torch.stack(all_stage_labels, dim=0) + stacked_progress_targets = torch.stack(all_progress_targets, dim=0) + + # Stack soft labels if any exist + stacked_soft_labels = None + if has_any_soft_labels: + soft_labels_tensors = [] + for i, soft_labels in enumerate(all_soft_stage_labels): + if soft_labels is not None: + soft_labels_tensors.append(soft_labels) + else: + # Create one-hot from hard labels + one_hot = torch.zeros(seq_len, num_stages, dtype=torch.float32) + for t in range(seq_len): + one_hot[t, all_stage_labels[i][t]] = 1.0 + soft_labels_tensors.append(one_hot) + stacked_soft_labels = torch.stack(soft_labels_tensors, dim=0) + + return stacked_stage_labels, stacked_progress_targets, stacked_soft_labels def __call__(self, transition: EnvTransition) -> EnvTransition: """Encode images, text, and normalize states in the transition.""" @@ -371,14 +479,16 @@ class SARMEncodingProcessorStep(ProcessorStep): observation['remaining_length'] = remaining observation['episode_length'] = ep_lengths - # Generate stage labels and progress targets from subtask annotations + # Generate stage labels, progress targets, and soft stage labels from subtask annotations if self.temporal_proportions is not None and self.dataset_meta is not None: - stage_labels, progress_targets = self._generate_stage_and_progress_labels( + stage_labels, progress_targets, soft_stage_labels = self._generate_stage_and_progress_labels( frame_index, episode_index, video_features ) if stage_labels is not None: observation['stage_labels'] = stage_labels observation['progress_targets'] = progress_targets + if soft_stage_labels is not None: + observation['soft_stage_labels'] = soft_stage_labels new_transition[TransitionKey.OBSERVATION] = observation return new_transition