Fix discrepancies with the original ReWiND implementation

This commit is contained in:
Pepijn
2025-11-18 16:09:16 +01:00
parent 0d84f4724d
commit 52b080fd8c
4 changed files with 61 additions and 57 deletions
@@ -342,7 +342,7 @@ class VideoAnnotator:
file_path, file_path,
start_timestamp=start_timestamp, start_timestamp=start_timestamp,
end_timestamp=end_timestamp, end_timestamp=end_timestamp,
target_fps=2 # 2 FPS is good balance for VLM target_fps=1
) )
is_extracted = extracted_path != file_path is_extracted = extracted_path != file_path
+47 -41
View File
@@ -33,9 +33,7 @@ def sample_video_feature(
video_feature: torch.Tensor, video_feature: torch.Tensor,
max_length: int = 32, max_length: int = 32,
random_sample: bool = True, random_sample: bool = True,
remaining_length: int = None, remaining_length: int = None
absolute_indices: torch.Tensor = None,
episode_length: int = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Sample or pad video features to a fixed length with progress targets. Sample or pad video features to a fixed length with progress targets.
@@ -52,18 +50,13 @@ def sample_video_feature(
This ensures all sequences show increasing progress from near-zero, regardless This ensures all sequences show increasing progress from near-zero, regardless
of where they're sampled from in the episode. of where they're sampled from in the episode.
Note: ReWiND uses consecutive frames loaded via observation_delta_indices. Uses original ReWiND sampling: random start/end points with minimum 3 frames.
When video_length > max_length, this function can subsample, but ReWiND
typically loads exactly max_length frames, so no subsampling occurs.
Args: Args:
video_feature: Video features tensor (num_frames, feature_dim) video_feature: Video features tensor (num_frames, feature_dim)
max_length: Target sequence length max_length: Target sequence length
random_sample: If True, randomly sample frames. If False, uniformly sample consecutive frames. random_sample: If True, randomly sample frames. If False, uniformly sample consecutive frames.
ReWiND uses False to preserve temporal order.
remaining_length: Remaining trajectory length from first frame to episode end remaining_length: Remaining trajectory length from first frame to episode end
absolute_indices: Absolute frame indices in the episode (num_frames,) [for fallback]
episode_length: Total length of the episode [for fallback]
Returns: Returns:
Tuple of: Tuple of:
@@ -72,21 +65,29 @@ def sample_video_feature(
""" """
video_length = len(video_feature) video_length = len(video_feature)
# Original ReWiND sampling: random start/end with minimum 3 frames
if video_length > 3:
# Sample random start index (ensuring we can get at least 3 frames)
start_idx = random.randint(0, max(0, video_length - 3))
# Sample random end index (at least 3 frames after start, up to video_length)
end_idx = random.randint(min(start_idx + 3, video_length), video_length)
# Extract the sampled segment
video_feature = video_feature[start_idx:end_idx]
# Update video_length for the sampled segment
video_length = len(video_feature)
# Adjust remaining_length to be from start_idx to episode end
if remaining_length is not None:
# The remaining length should be from start_idx to episode end
# If we started at start_idx, we've already consumed start_idx frames
remaining_length = remaining_length - start_idx if remaining_length > start_idx else video_length
# Generate progress targets using ORIGINAL ReWiND formula # Generate progress targets using ORIGINAL ReWiND formula
# Progress = (position_in_sequence + 1) / remaining_trajectory_length # Progress = (position_in_sequence + 1) / remaining_trajectory_length
if remaining_length is not None: progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
# CORRECT: Use remaining length from first frame to episode end progress_targets = progress_indices / remaining_length
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
progress_targets = progress_indices / remaining_length
elif absolute_indices is not None and episode_length is not None:
# Fallback: Compute remaining length from first frame to episode end
first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0]
remaining_length_computed = episode_length - first_frame_idx
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
progress_targets = progress_indices / remaining_length_computed
else:
# Fallback: linear progress (for inference/testing)
progress_targets = torch.linspace(1.0/video_length, 1.0, video_length)
if video_length < max_length: if video_length < max_length:
# Pad with last frame # Pad with last frame
@@ -117,15 +118,13 @@ def sample_reverse_video_feature(
video_feature: torch.Tensor, video_feature: torch.Tensor,
max_length: int = 32, max_length: int = 32,
random_sample: bool = True, random_sample: bool = True,
remaining_length: int = None, remaining_length: int = None
absolute_indices: torch.Tensor = None,
episode_length: int = None
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Sample video with reverse augmentation (video rewind) - ORIGINAL REWIND LOGIC. Sample video with reverse augmentation (video rewind) - ORIGINAL REWIND LOGIC.
This implements the EXACT video rewind augmentation from the original ReWiND paper: This implements the EXACT video rewind augmentation from the original ReWiND paper:
1. Take forward sequence 1. Take forward sequence (sampled with random start/end, min 3 frames)
2. Append reversed frames from the END backwards 2. Append reversed frames from the END backwards
3. Progress increases then decreases (simulating task completion then failure) 3. Progress increases then decreases (simulating task completion then failure)
@@ -139,8 +138,6 @@ def sample_reverse_video_feature(
max_length: Target sequence length max_length: Target sequence length
random_sample: If True, use random sampling for frame selection random_sample: If True, use random sampling for frame selection
remaining_length: Remaining trajectory length from first frame to episode end remaining_length: Remaining trajectory length from first frame to episode end
absolute_indices: Absolute frame indices in the episode (num_frames,) [for fallback]
episode_length: Total length of the episode [for fallback]
Returns: Returns:
Tuple of: Tuple of:
@@ -149,21 +146,30 @@ def sample_reverse_video_feature(
""" """
video_length = len(video_feature) video_length = len(video_feature)
# Original logic: start from first half, end in second half, ensure min 3 frames
if video_length > 3:
# Sample start from first half
start_idx = random.randint(0, video_length // 2)
# Sample end from second half
end_idx = random.randint(video_length // 2, video_length)
# Ensure minimum 3 frames difference (original uses while loop)
while end_idx - start_idx < 3:
start_idx = random.randint(0, video_length // 2)
end_idx = random.randint(video_length // 2, video_length)
# Extract the forward segment
video_feature = video_feature[start_idx:end_idx]
video_length = len(video_feature)
# Adjust remaining_length
if remaining_length is not None:
remaining_length = remaining_length - start_idx if remaining_length > start_idx else video_length
# Generate forward progress targets using ORIGINAL ReWiND formula # Generate forward progress targets using ORIGINAL ReWiND formula
# Progress = (position_in_sequence + 1) / remaining_trajectory_length # Progress = (position_in_sequence + 1) / remaining_trajectory_length
if remaining_length is not None: progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
# CORRECT: Use remaining length from first frame to episode end forward_progress = progress_indices / remaining_length
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
forward_progress = progress_indices / remaining_length
elif absolute_indices is not None and episode_length is not None:
# Fallback: Compute remaining length from first frame to episode end
first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0]
remaining_length_computed = episode_length - first_frame_idx
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
forward_progress = progress_indices / remaining_length_computed
else:
# Fallback: linear progress
forward_progress = torch.linspace(1.0/video_length, 1.0, video_length)
# ORIGINAL LOGIC: Reverse from END backwards, then append to forward sequence # ORIGINAL LOGIC: Reverse from END backwards, then append to forward sequence
# Example: video=[A,B,C,D,E] -> reversed=[E,D,C,B,A] -> take some from reversed (skip first) # Example: video=[A,B,C,D,E] -> reversed=[E,D,C,B,A] -> take some from reversed (skip first)
@@ -39,10 +39,11 @@ class ReWiNDConfig(PreTrainedConfig):
num_layers: int = 4 num_layers: int = 4
# Temporal parameters # Temporal parameters
max_length: int = 32 # Maximum video sequence length max_length: int = 32 # Maximum video sequence length, ORIGINAL: 16!
subsample_video: bool = True # Whether to pad/subsample videos to max_length subsample_video: bool = True # Whether to pad/subsample videos to max_length
use_temporal_sampler: bool = True # Always enable temporal sequence loading use_temporal_sampler: bool = True # Always enable temporal sequence loading
sequence_stride: int = 1 # Stride between frames when using temporal sampler sequence_stride: int = 1 # Stride between frames when using temporal sampler
rewind_ratio: float = 0.8 # Probability of applying rewind augmentation (original: 0.8)
# Training parameters # Training parameters
batch_size: int = 64 batch_size: int = 64
@@ -91,7 +92,7 @@ class ReWiNDConfig(PreTrainedConfig):
def get_optimizer_preset(self) -> AdamWConfig: def get_optimizer_preset(self) -> AdamWConfig:
"""Get default optimizer configuration for ReWiND training.""" """Get default optimizer configuration for ReWiND training."""
return AdamWConfig( return AdamWConfig(
lr=3e-4, lr=1e-4,
weight_decay=1e-4, weight_decay=1e-4,
betas=(0.9, 0.999), betas=(0.9, 0.999),
eps=1e-8, eps=1e-8,
@@ -100,8 +101,8 @@ class ReWiNDConfig(PreTrainedConfig):
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig: def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
"""Get default learning rate scheduler configuration.""" """Get default learning rate scheduler configuration."""
return CosineDecayWithWarmupSchedulerConfig( return CosineDecayWithWarmupSchedulerConfig(
peak_lr=3e-4, peak_lr=1e-4,
decay_lr=3e-5, decay_lr=1e-5,
num_warmup_steps=1000, num_warmup_steps=1000,
num_decay_steps=100000, num_decay_steps=100000,
) )
+8 -11
View File
@@ -569,15 +569,13 @@ class ReWiNDRewardModel(PreTrainedPolicy):
else: else:
current_remaining_length = remaining_lengths.item() if isinstance(remaining_lengths, torch.Tensor) else remaining_lengths current_remaining_length = remaining_lengths.item() if isinstance(remaining_lengths, torch.Tensor) else remaining_lengths
if random.random() < 0.5: # 50% chance of rewind if random.random() < self.config.rewind_ratio: # Use configurable rewind ratio
# Apply video rewind augmentation (now returns tuple) # Apply video rewind augmentation (now returns tuple)
rewound_video, progress = sample_reverse_video_feature( rewound_video, progress = sample_reverse_video_feature(
video_features[i], video_features[i],
max_length=max_length, max_length=max_length,
random_sample=False, # Use consecutive frames, not random sampling random_sample=True, # Use random sampling (original ReWiND)
remaining_length=current_remaining_length, remaining_length=current_remaining_length
absolute_indices=current_absolute_indices,
episode_length=current_episode_length
) )
processed_videos.append(rewound_video.to(self.device)) processed_videos.append(rewound_video.to(self.device))
progress_targets.append(progress.to(self.device)) progress_targets.append(progress.to(self.device))
@@ -586,10 +584,8 @@ class ReWiNDRewardModel(PreTrainedPolicy):
sampled_video, progress = sample_video_feature( sampled_video, progress = sample_video_feature(
video_features[i], video_features[i],
max_length=max_length, max_length=max_length,
random_sample=False, # Use consecutive frames, not random sampling random_sample=True, # Use random sampling (original ReWiND)
remaining_length=current_remaining_length, remaining_length=current_remaining_length
absolute_indices=current_absolute_indices,
episode_length=current_episode_length
) )
processed_videos.append(sampled_video.to(self.device)) processed_videos.append(sampled_video.to(self.device))
progress_targets.append(progress.to(self.device)) progress_targets.append(progress.to(self.device))
@@ -623,12 +619,13 @@ class ReWiNDRewardModel(PreTrainedPolicy):
# For misaligned pairs, we don't need correct progress targets (will be set to 0) # For misaligned pairs, we don't need correct progress targets (will be set to 0)
misaligned_videos_sampled = [] misaligned_videos_sampled = []
for i in range(batch_size): for i in range(batch_size):
# For misaligned videos, use video length as remaining_length
video_len = len(misaligned_videos[i])
sampled, _ = sample_video_feature( sampled, _ = sample_video_feature(
misaligned_videos[i], misaligned_videos[i],
max_length=max_length, max_length=max_length,
random_sample=True, random_sample=True,
absolute_indices=None, remaining_length=video_len # Use video length for misaligned pairs
episode_length=None
) )
misaligned_videos_sampled.append(sampled.to(self.device)) misaligned_videos_sampled.append(sampled.to(self.device))
misaligned_videos_sampled = torch.stack(misaligned_videos_sampled) misaligned_videos_sampled = torch.stack(misaligned_videos_sampled)