diff --git a/examples/dataset_annotation/subtask_annotation.py b/examples/dataset_annotation/subtask_annotation.py
index ef2532fbd..4493f91e8 100644
--- a/examples/dataset_annotation/subtask_annotation.py
+++ b/examples/dataset_annotation/subtask_annotation.py
@@ -342,7 +342,7 @@ class VideoAnnotator:
                 file_path,
                 start_timestamp=start_timestamp,
                 end_timestamp=end_timestamp,
-                target_fps=2  # 2 FPS is good balance for VLM
+                target_fps=1
             )
             is_extracted = extracted_path != file_path
diff --git a/src/lerobot/datasets/video_sampler.py b/src/lerobot/datasets/video_sampler.py
index f33e14b1b..222167cc8 100644
--- a/src/lerobot/datasets/video_sampler.py
+++ b/src/lerobot/datasets/video_sampler.py
@@ -33,9 +33,7 @@ def sample_video_feature(
     video_feature: torch.Tensor,
     max_length: int = 32,
     random_sample: bool = True,
-    remaining_length: int = None,
-    absolute_indices: torch.Tensor = None,
-    episode_length: int = None
+    remaining_length: int = None
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Sample or pad video features to a fixed length with progress targets.
@@ -52,18 +50,13 @@
     This ensures all sequences show increasing progress from near-zero,
     regardless of where they're sampled from in the episode.
 
-    Note: ReWiND uses consecutive frames loaded via observation_delta_indices.
-    When video_length > max_length, this function can subsample, but ReWiND
-    typically loads exactly max_length frames, so no subsampling occurs.
+    Uses original ReWiND sampling: random start/end points with minimum 3 frames.
 
     Args:
         video_feature: Video features tensor (num_frames, feature_dim)
         max_length: Target sequence length
        random_sample: If True, randomly sample frames. If False, uniformly sample consecutive frames.
-                       ReWiND uses False to preserve temporal order.
         remaining_length: Remaining trajectory length from first frame to episode end
-        absolute_indices: Absolute frame indices in the episode (num_frames,) [for fallback]
-        episode_length: Total length of the episode [for fallback]
 
     Returns:
         Tuple of:
@@ -72,21 +65,29 @@
     """
     video_length = len(video_feature)
 
+    # Original ReWiND sampling: random start/end with minimum 3 frames
+    if video_length > 3:
+        # Sample random start index (ensuring we can get at least 3 frames)
+        start_idx = random.randint(0, max(0, video_length - 3))
+        # Sample random end index (at least 3 frames after start, up to video_length)
+        end_idx = random.randint(min(start_idx + 3, video_length), video_length)
+
+        # Extract the sampled segment
+        video_feature = video_feature[start_idx:end_idx]
+
+        # Update video_length for the sampled segment
+        video_length = len(video_feature)
+
+        # Adjust remaining_length to be from start_idx to episode end
+        if remaining_length is not None:
+            # The remaining length should be from start_idx to episode end
+            # If we started at start_idx, we've already consumed start_idx frames
+            remaining_length = remaining_length - start_idx if remaining_length > start_idx else video_length
+
     # Generate progress targets using ORIGINAL ReWiND formula
     # Progress = (position_in_sequence + 1) / remaining_trajectory_length
-    if remaining_length is not None:
-        # CORRECT: Use remaining length from first frame to episode end
-        progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
-        progress_targets = progress_indices / remaining_length
-    elif absolute_indices is not None and episode_length is not None:
-        # Fallback: Compute remaining length from first frame to episode end
-        first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0]
-        remaining_length_computed = episode_length - first_frame_idx
-        progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
-        progress_targets = progress_indices / remaining_length_computed
-    else:
-        # Fallback: linear progress (for inference/testing)
-        progress_targets = torch.linspace(1.0/video_length, 1.0, video_length)
+    progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
+    progress_targets = progress_indices / remaining_length
 
     if video_length < max_length:
         # Pad with last frame
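For reference, the forward-sampling path introduced above can be exercised on its own. The snippet below is a minimal standalone sketch of that logic (not the library function itself); the helper name and the toy tensor are illustrative only.

    import random
    import torch

    def _sample_forward_segment(video_feature: torch.Tensor, remaining_length: int):
        # Standalone sketch of the sampling added to sample_video_feature above.
        video_length = len(video_feature)
        if video_length > 3:
            # Random start index, leaving room for at least 3 frames
            start_idx = random.randint(0, max(0, video_length - 3))
            # Random end index, at least 3 frames after the start
            end_idx = random.randint(min(start_idx + 3, video_length), video_length)
            video_feature = video_feature[start_idx:end_idx]
            video_length = len(video_feature)
            # Shift the remaining length to account for the skipped prefix
            remaining_length = remaining_length - start_idx if remaining_length > start_idx else video_length
        # Progress target for the k-th frame (1-indexed) is k / remaining_length
        progress = torch.arange(1, video_length + 1, dtype=torch.float32) / remaining_length
        return video_feature, progress

    features = torch.randn(20, 512)  # 20 frames of 512-d features (toy input)
    segment, progress = _sample_forward_segment(features, remaining_length=20)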
@@ -117,15 +118,13 @@ def sample_reverse_video_feature(
     video_feature: torch.Tensor,
     max_length: int = 32,
     random_sample: bool = True,
-    remaining_length: int = None,
-    absolute_indices: torch.Tensor = None,
-    episode_length: int = None
+    remaining_length: int = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Sample video with reverse augmentation (video rewind) - ORIGINAL REWIND LOGIC.
 
     This implements the EXACT video rewind augmentation from the original ReWiND paper:
-    1. Take forward sequence
+    1. Take forward sequence (sampled with random start/end, min 3 frames)
     2. Append reversed frames from the END backwards
     3. Progress increases then decreases (simulating task completion then failure)
@@ -139,8 +138,6 @@
         max_length: Target sequence length
         random_sample: If True, use random sampling for frame selection
         remaining_length: Remaining trajectory length from first frame to episode end
-        absolute_indices: Absolute frame indices in the episode (num_frames,) [for fallback]
-        episode_length: Total length of the episode [for fallback]
 
     Returns:
         Tuple of:
@@ -149,21 +146,30 @@
     """
     video_length = len(video_feature)
 
+    # Original logic: start from first half, end in second half, ensure min 3 frames
+    if video_length > 3:
+        # Sample start from first half
+        start_idx = random.randint(0, video_length // 2)
+        # Sample end from second half
+        end_idx = random.randint(video_length // 2, video_length)
+
+        # Ensure minimum 3 frames difference (original uses while loop)
+        while end_idx - start_idx < 3:
+            start_idx = random.randint(0, video_length // 2)
+            end_idx = random.randint(video_length // 2, video_length)
+
+        # Extract the forward segment
+        video_feature = video_feature[start_idx:end_idx]
+        video_length = len(video_feature)
+
+        # Adjust remaining_length
+        if remaining_length is not None:
+            remaining_length = remaining_length - start_idx if remaining_length > start_idx else video_length
+
     # Generate forward progress targets using ORIGINAL ReWiND formula
     # Progress = (position_in_sequence + 1) / remaining_trajectory_length
-    if remaining_length is not None:
-        # CORRECT: Use remaining length from first frame to episode end
-        progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
-        forward_progress = progress_indices / remaining_length
-    elif absolute_indices is not None and episode_length is not None:
-        # Fallback: Compute remaining length from first frame to episode end
-        first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0]
-        remaining_length_computed = episode_length - first_frame_idx
-        progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
-        forward_progress = progress_indices / remaining_length_computed
-    else:
-        # Fallback: linear progress
-        forward_progress = torch.linspace(1.0/video_length, 1.0, video_length)
+    progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
+    forward_progress = progress_indices / remaining_length
 
     # ORIGINAL LOGIC: Reverse from END backwards, then append to forward sequence
     # Example: video=[A,B,C,D,E] -> reversed=[E,D,C,B,A] -> take some from reversed (skip first)
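The tail of the rewind path (appending reversed frames and mirroring the progress targets) falls outside the hunk above, so only its comments are visible here. The snippet below is a rough illustration of that idea, assuming the rise-then-fall progress behaviour described in the docstring; it is not the function's actual code.

    import torch

    def _rewind_illustration(forward: torch.Tensor, forward_progress: torch.Tensor, n_rewind: int):
        # Walk backwards from the end (skipping the final frame itself) and append,
        # so progress rises to its peak and then falls again.
        rewound_frames = forward.flip(0)[1 : n_rewind + 1]
        rewound_progress = forward_progress.flip(0)[1 : n_rewind + 1]
        return torch.cat([forward, rewound_frames]), torch.cat([forward_progress, rewound_progress])

    frames = torch.arange(5, dtype=torch.float32).unsqueeze(1)   # stand-ins for frames A..E
    forward_progress = torch.tensor([0.2, 0.4, 0.6, 0.8, 1.0])
    _, progress = _rewind_illustration(frames, forward_progress, n_rewind=3)
    # progress: 0.2 0.4 0.6 0.8 1.0 0.8 0.6 0.4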
diff --git a/src/lerobot/policies/rewind/configuration_rewind.py b/src/lerobot/policies/rewind/configuration_rewind.py
index 2280ca634..9a9b61584 100644
--- a/src/lerobot/policies/rewind/configuration_rewind.py
+++ b/src/lerobot/policies/rewind/configuration_rewind.py
@@ -39,10 +39,11 @@ class ReWiNDConfig(PreTrainedConfig):
     num_layers: int = 4
 
     # Temporal parameters
-    max_length: int = 32  # Maximum video sequence length
+    max_length: int = 32  # Maximum video sequence length, ORIGINAL: 16!
     subsample_video: bool = True  # Whether to pad/subsample videos to max_length
     use_temporal_sampler: bool = True  # Always enable temporal sequence loading
     sequence_stride: int = 1  # Stride between frames when using temporal sampler
+    rewind_ratio: float = 0.8  # Probability of applying rewind augmentation (original: 0.8)
 
     # Training parameters
     batch_size: int = 64
@@ -91,7 +92,7 @@ class ReWiNDConfig(PreTrainedConfig):
     def get_optimizer_preset(self) -> AdamWConfig:
         """Get default optimizer configuration for ReWiND training."""
         return AdamWConfig(
-            lr=3e-4,
+            lr=1e-4,
             weight_decay=1e-4,
             betas=(0.9, 0.999),
             eps=1e-8,
@@ -100,8 +101,8 @@ class ReWiNDConfig(PreTrainedConfig):
     def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
         """Get default learning rate scheduler configuration."""
         return CosineDecayWithWarmupSchedulerConfig(
-            peak_lr=3e-4,
-            decay_lr=3e-5,
+            peak_lr=1e-4,
+            decay_lr=1e-5,
             num_warmup_steps=1000,
             num_decay_steps=100000,
         )
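Sketch of how the new defaults would surface to a training script, assuming ReWiNDConfig can be instantiated with defaults; constructor arguments beyond the fields shown in the diff are not assumed.

    from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig

    cfg = ReWiNDConfig()
    print(cfg.max_length)    # 32 here; the original ReWiND setup used 16
    print(cfg.rewind_ratio)  # 0.8, probability of applying the rewind augmentation
    optimizer_cfg = cfg.get_optimizer_preset()   # AdamW preset, lr now 1e-4
    scheduler_cfg = cfg.get_scheduler_preset()   # cosine decay, peak_lr 1e-4 -> decay_lr 1e-5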
diff --git a/src/lerobot/policies/rewind/modeling_rewind.py b/src/lerobot/policies/rewind/modeling_rewind.py
index 7adc131aa..81576770a 100644
--- a/src/lerobot/policies/rewind/modeling_rewind.py
+++ b/src/lerobot/policies/rewind/modeling_rewind.py
@@ -569,15 +569,13 @@ class ReWiNDRewardModel(PreTrainedPolicy):
             else:
                 current_remaining_length = remaining_lengths.item() if isinstance(remaining_lengths, torch.Tensor) else remaining_lengths
 
-            if random.random() < 0.5:  # 50% chance of rewind
+            if random.random() < self.config.rewind_ratio:  # Use configurable rewind ratio
                 # Apply video rewind augmentation (now returns tuple)
                 rewound_video, progress = sample_reverse_video_feature(
                     video_features[i],
                     max_length=max_length,
-                    random_sample=False,  # Use consecutive frames, not random sampling
-                    remaining_length=current_remaining_length,
-                    absolute_indices=current_absolute_indices,
-                    episode_length=current_episode_length
+                    random_sample=True,  # Use random sampling (original ReWiND)
+                    remaining_length=current_remaining_length
                 )
                 processed_videos.append(rewound_video.to(self.device))
                 progress_targets.append(progress.to(self.device))
@@ -586,10 +584,8 @@ class ReWiNDRewardModel(PreTrainedPolicy):
                 sampled_video, progress = sample_video_feature(
                     video_features[i],
                     max_length=max_length,
-                    random_sample=False,  # Use consecutive frames, not random sampling
-                    remaining_length=current_remaining_length,
-                    absolute_indices=current_absolute_indices,
-                    episode_length=current_episode_length
+                    random_sample=True,  # Use random sampling (original ReWiND)
+                    remaining_length=current_remaining_length
                 )
                 processed_videos.append(sampled_video.to(self.device))
                 progress_targets.append(progress.to(self.device))
@@ -623,12 +619,13 @@ class ReWiNDRewardModel(PreTrainedPolicy):
         # For misaligned pairs, we don't need correct progress targets (will be set to 0)
         misaligned_videos_sampled = []
         for i in range(batch_size):
+            # For misaligned videos, use video length as remaining_length
+            video_len = len(misaligned_videos[i])
             sampled, _ = sample_video_feature(
                 misaligned_videos[i],
                 max_length=max_length,
                 random_sample=True,
-                absolute_indices=None,
-                episode_length=None
+                remaining_length=video_len  # Use video length for misaligned pairs
             )
            misaligned_videos_sampled.append(sampled.to(self.device))
         misaligned_videos_sampled = torch.stack(misaligned_videos_sampled)
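Taken together, the per-sample branching in the training loop above reduces to the following shape. This is a distilled sketch, not the method itself; sample_video_feature and sample_reverse_video_feature are the functions patched in src/lerobot/datasets/video_sampler.py, and the helper name is illustrative.

    import random

    from lerobot.datasets.video_sampler import sample_reverse_video_feature, sample_video_feature

    def _augment_one_sample(video_feature, remaining_length, rewind_ratio, max_length=32):
        # With probability rewind_ratio, apply the rewind augmentation; otherwise
        # use plain forward sampling. Both paths now receive remaining_length only.
        if random.random() < rewind_ratio:
            return sample_reverse_video_feature(
                video_feature,
                max_length=max_length,
                random_sample=True,
                remaining_length=remaining_length,
            )
        return sample_video_feature(
            video_feature,
            max_length=max_length,
            random_sample=True,
            remaining_length=remaining_length,
        )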