mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
fix rewind discrepancies
This commit is contained in:
@@ -342,7 +342,7 @@ class VideoAnnotator:
|
|||||||
file_path,
|
file_path,
|
||||||
start_timestamp=start_timestamp,
|
start_timestamp=start_timestamp,
|
||||||
end_timestamp=end_timestamp,
|
end_timestamp=end_timestamp,
|
||||||
target_fps=2 # 2 FPS is good balance for VLM
|
target_fps=1
|
||||||
)
|
)
|
||||||
is_extracted = extracted_path != file_path
|
is_extracted = extracted_path != file_path
|
||||||
|
|
||||||
|
|||||||
@@ -33,9 +33,7 @@ def sample_video_feature(
|
|||||||
video_feature: torch.Tensor,
|
video_feature: torch.Tensor,
|
||||||
max_length: int = 32,
|
max_length: int = 32,
|
||||||
random_sample: bool = True,
|
random_sample: bool = True,
|
||||||
remaining_length: int = None,
|
remaining_length: int = None
|
||||||
absolute_indices: torch.Tensor = None,
|
|
||||||
episode_length: int = None
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Sample or pad video features to a fixed length with progress targets.
|
Sample or pad video features to a fixed length with progress targets.
|
||||||
@@ -52,18 +50,13 @@ def sample_video_feature(
|
|||||||
This ensures all sequences show increasing progress from near-zero, regardless
|
This ensures all sequences show increasing progress from near-zero, regardless
|
||||||
of where they're sampled from in the episode.
|
of where they're sampled from in the episode.
|
||||||
|
|
||||||
Note: ReWiND uses consecutive frames loaded via observation_delta_indices.
|
Uses original ReWiND sampling: random start/end points with minimum 3 frames.
|
||||||
When video_length > max_length, this function can subsample, but ReWiND
|
|
||||||
typically loads exactly max_length frames, so no subsampling occurs.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_feature: Video features tensor (num_frames, feature_dim)
|
video_feature: Video features tensor (num_frames, feature_dim)
|
||||||
max_length: Target sequence length
|
max_length: Target sequence length
|
||||||
random_sample: If True, randomly sample frames. If False, uniformly sample consecutive frames.
|
random_sample: If True, randomly sample frames. If False, uniformly sample consecutive frames.
|
||||||
ReWiND uses False to preserve temporal order.
|
|
||||||
remaining_length: Remaining trajectory length from first frame to episode end
|
remaining_length: Remaining trajectory length from first frame to episode end
|
||||||
absolute_indices: Absolute frame indices in the episode (num_frames,) [for fallback]
|
|
||||||
episode_length: Total length of the episode [for fallback]
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of:
|
Tuple of:
|
||||||
@@ -72,21 +65,29 @@ def sample_video_feature(
|
|||||||
"""
|
"""
|
||||||
video_length = len(video_feature)
|
video_length = len(video_feature)
|
||||||
|
|
||||||
|
# Original ReWiND sampling: random start/end with minimum 3 frames
|
||||||
|
if video_length > 3:
|
||||||
|
# Sample random start index (ensuring we can get at least 3 frames)
|
||||||
|
start_idx = random.randint(0, max(0, video_length - 3))
|
||||||
|
# Sample random end index (at least 3 frames after start, up to video_length)
|
||||||
|
end_idx = random.randint(min(start_idx + 3, video_length), video_length)
|
||||||
|
|
||||||
|
# Extract the sampled segment
|
||||||
|
video_feature = video_feature[start_idx:end_idx]
|
||||||
|
|
||||||
|
# Update video_length for the sampled segment
|
||||||
|
video_length = len(video_feature)
|
||||||
|
|
||||||
|
# Adjust remaining_length to be from start_idx to episode end
|
||||||
|
if remaining_length is not None:
|
||||||
|
# The remaining length should be from start_idx to episode end
|
||||||
|
# If we started at start_idx, we've already consumed start_idx frames
|
||||||
|
remaining_length = remaining_length - start_idx if remaining_length > start_idx else video_length
|
||||||
|
|
||||||
# Generate progress targets using ORIGINAL ReWiND formula
|
# Generate progress targets using ORIGINAL ReWiND formula
|
||||||
# Progress = (position_in_sequence + 1) / remaining_trajectory_length
|
# Progress = (position_in_sequence + 1) / remaining_trajectory_length
|
||||||
if remaining_length is not None:
|
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
|
||||||
# CORRECT: Use remaining length from first frame to episode end
|
progress_targets = progress_indices / remaining_length
|
||||||
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
|
|
||||||
progress_targets = progress_indices / remaining_length
|
|
||||||
elif absolute_indices is not None and episode_length is not None:
|
|
||||||
# Fallback: Compute remaining length from first frame to episode end
|
|
||||||
first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0]
|
|
||||||
remaining_length_computed = episode_length - first_frame_idx
|
|
||||||
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
|
|
||||||
progress_targets = progress_indices / remaining_length_computed
|
|
||||||
else:
|
|
||||||
# Fallback: linear progress (for inference/testing)
|
|
||||||
progress_targets = torch.linspace(1.0/video_length, 1.0, video_length)
|
|
||||||
|
|
||||||
if video_length < max_length:
|
if video_length < max_length:
|
||||||
# Pad with last frame
|
# Pad with last frame
|
||||||
@@ -117,15 +118,13 @@ def sample_reverse_video_feature(
|
|||||||
video_feature: torch.Tensor,
|
video_feature: torch.Tensor,
|
||||||
max_length: int = 32,
|
max_length: int = 32,
|
||||||
random_sample: bool = True,
|
random_sample: bool = True,
|
||||||
remaining_length: int = None,
|
remaining_length: int = None
|
||||||
absolute_indices: torch.Tensor = None,
|
|
||||||
episode_length: int = None
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Sample video with reverse augmentation (video rewind) - ORIGINAL REWIND LOGIC.
|
Sample video with reverse augmentation (video rewind) - ORIGINAL REWIND LOGIC.
|
||||||
|
|
||||||
This implements the EXACT video rewind augmentation from the original ReWiND paper:
|
This implements the EXACT video rewind augmentation from the original ReWiND paper:
|
||||||
1. Take forward sequence
|
1. Take forward sequence (sampled with random start/end, min 3 frames)
|
||||||
2. Append reversed frames from the END backwards
|
2. Append reversed frames from the END backwards
|
||||||
3. Progress increases then decreases (simulating task completion then failure)
|
3. Progress increases then decreases (simulating task completion then failure)
|
||||||
|
|
||||||
@@ -139,8 +138,6 @@ def sample_reverse_video_feature(
|
|||||||
max_length: Target sequence length
|
max_length: Target sequence length
|
||||||
random_sample: If True, use random sampling for frame selection
|
random_sample: If True, use random sampling for frame selection
|
||||||
remaining_length: Remaining trajectory length from first frame to episode end
|
remaining_length: Remaining trajectory length from first frame to episode end
|
||||||
absolute_indices: Absolute frame indices in the episode (num_frames,) [for fallback]
|
|
||||||
episode_length: Total length of the episode [for fallback]
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of:
|
Tuple of:
|
||||||
@@ -149,21 +146,30 @@ def sample_reverse_video_feature(
|
|||||||
"""
|
"""
|
||||||
video_length = len(video_feature)
|
video_length = len(video_feature)
|
||||||
|
|
||||||
|
# Original logic: start from first half, end in second half, ensure min 3 frames
|
||||||
|
if video_length > 3:
|
||||||
|
# Sample start from first half
|
||||||
|
start_idx = random.randint(0, video_length // 2)
|
||||||
|
# Sample end from second half
|
||||||
|
end_idx = random.randint(video_length // 2, video_length)
|
||||||
|
|
||||||
|
# Ensure minimum 3 frames difference (original uses while loop)
|
||||||
|
while end_idx - start_idx < 3:
|
||||||
|
start_idx = random.randint(0, video_length // 2)
|
||||||
|
end_idx = random.randint(video_length // 2, video_length)
|
||||||
|
|
||||||
|
# Extract the forward segment
|
||||||
|
video_feature = video_feature[start_idx:end_idx]
|
||||||
|
video_length = len(video_feature)
|
||||||
|
|
||||||
|
# Adjust remaining_length
|
||||||
|
if remaining_length is not None:
|
||||||
|
remaining_length = remaining_length - start_idx if remaining_length > start_idx else video_length
|
||||||
|
|
||||||
# Generate forward progress targets using ORIGINAL ReWiND formula
|
# Generate forward progress targets using ORIGINAL ReWiND formula
|
||||||
# Progress = (position_in_sequence + 1) / remaining_trajectory_length
|
# Progress = (position_in_sequence + 1) / remaining_trajectory_length
|
||||||
if remaining_length is not None:
|
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
|
||||||
# CORRECT: Use remaining length from first frame to episode end
|
forward_progress = progress_indices / remaining_length
|
||||||
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
|
|
||||||
forward_progress = progress_indices / remaining_length
|
|
||||||
elif absolute_indices is not None and episode_length is not None:
|
|
||||||
# Fallback: Compute remaining length from first frame to episode end
|
|
||||||
first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0]
|
|
||||||
remaining_length_computed = episode_length - first_frame_idx
|
|
||||||
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
|
|
||||||
forward_progress = progress_indices / remaining_length_computed
|
|
||||||
else:
|
|
||||||
# Fallback: linear progress
|
|
||||||
forward_progress = torch.linspace(1.0/video_length, 1.0, video_length)
|
|
||||||
|
|
||||||
# ORIGINAL LOGIC: Reverse from END backwards, then append to forward sequence
|
# ORIGINAL LOGIC: Reverse from END backwards, then append to forward sequence
|
||||||
# Example: video=[A,B,C,D,E] -> reversed=[E,D,C,B,A] -> take some from reversed (skip first)
|
# Example: video=[A,B,C,D,E] -> reversed=[E,D,C,B,A] -> take some from reversed (skip first)
|
||||||
|
|||||||
@@ -39,10 +39,11 @@ class ReWiNDConfig(PreTrainedConfig):
|
|||||||
num_layers: int = 4
|
num_layers: int = 4
|
||||||
|
|
||||||
# Temporal parameters
|
# Temporal parameters
|
||||||
max_length: int = 32 # Maximum video sequence length
|
max_length: int = 32 # Maximum video sequence length, ORIGINAL: 16!
|
||||||
subsample_video: bool = True # Whether to pad/subsample videos to max_length
|
subsample_video: bool = True # Whether to pad/subsample videos to max_length
|
||||||
use_temporal_sampler: bool = True # Always enable temporal sequence loading
|
use_temporal_sampler: bool = True # Always enable temporal sequence loading
|
||||||
sequence_stride: int = 1 # Stride between frames when using temporal sampler
|
sequence_stride: int = 1 # Stride between frames when using temporal sampler
|
||||||
|
rewind_ratio: float = 0.8 # Probability of applying rewind augmentation (original: 0.8)
|
||||||
|
|
||||||
# Training parameters
|
# Training parameters
|
||||||
batch_size: int = 64
|
batch_size: int = 64
|
||||||
@@ -91,7 +92,7 @@ class ReWiNDConfig(PreTrainedConfig):
|
|||||||
def get_optimizer_preset(self) -> AdamWConfig:
|
def get_optimizer_preset(self) -> AdamWConfig:
|
||||||
"""Get default optimizer configuration for ReWiND training."""
|
"""Get default optimizer configuration for ReWiND training."""
|
||||||
return AdamWConfig(
|
return AdamWConfig(
|
||||||
lr=3e-4,
|
lr=1e-4,
|
||||||
weight_decay=1e-4,
|
weight_decay=1e-4,
|
||||||
betas=(0.9, 0.999),
|
betas=(0.9, 0.999),
|
||||||
eps=1e-8,
|
eps=1e-8,
|
||||||
@@ -100,8 +101,8 @@ class ReWiNDConfig(PreTrainedConfig):
|
|||||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||||
"""Get default learning rate scheduler configuration."""
|
"""Get default learning rate scheduler configuration."""
|
||||||
return CosineDecayWithWarmupSchedulerConfig(
|
return CosineDecayWithWarmupSchedulerConfig(
|
||||||
peak_lr=3e-4,
|
peak_lr=1e-4,
|
||||||
decay_lr=3e-5,
|
decay_lr=1e-5,
|
||||||
num_warmup_steps=1000,
|
num_warmup_steps=1000,
|
||||||
num_decay_steps=100000,
|
num_decay_steps=100000,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -569,15 +569,13 @@ class ReWiNDRewardModel(PreTrainedPolicy):
|
|||||||
else:
|
else:
|
||||||
current_remaining_length = remaining_lengths.item() if isinstance(remaining_lengths, torch.Tensor) else remaining_lengths
|
current_remaining_length = remaining_lengths.item() if isinstance(remaining_lengths, torch.Tensor) else remaining_lengths
|
||||||
|
|
||||||
if random.random() < 0.5: # 50% chance of rewind
|
if random.random() < self.config.rewind_ratio: # Use configurable rewind ratio
|
||||||
# Apply video rewind augmentation (now returns tuple)
|
# Apply video rewind augmentation (now returns tuple)
|
||||||
rewound_video, progress = sample_reverse_video_feature(
|
rewound_video, progress = sample_reverse_video_feature(
|
||||||
video_features[i],
|
video_features[i],
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
random_sample=False, # Use consecutive frames, not random sampling
|
random_sample=True, # Use random sampling (original ReWiND)
|
||||||
remaining_length=current_remaining_length,
|
remaining_length=current_remaining_length
|
||||||
absolute_indices=current_absolute_indices,
|
|
||||||
episode_length=current_episode_length
|
|
||||||
)
|
)
|
||||||
processed_videos.append(rewound_video.to(self.device))
|
processed_videos.append(rewound_video.to(self.device))
|
||||||
progress_targets.append(progress.to(self.device))
|
progress_targets.append(progress.to(self.device))
|
||||||
@@ -586,10 +584,8 @@ class ReWiNDRewardModel(PreTrainedPolicy):
|
|||||||
sampled_video, progress = sample_video_feature(
|
sampled_video, progress = sample_video_feature(
|
||||||
video_features[i],
|
video_features[i],
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
random_sample=False, # Use consecutive frames, not random sampling
|
random_sample=True, # Use random sampling (original ReWiND)
|
||||||
remaining_length=current_remaining_length,
|
remaining_length=current_remaining_length
|
||||||
absolute_indices=current_absolute_indices,
|
|
||||||
episode_length=current_episode_length
|
|
||||||
)
|
)
|
||||||
processed_videos.append(sampled_video.to(self.device))
|
processed_videos.append(sampled_video.to(self.device))
|
||||||
progress_targets.append(progress.to(self.device))
|
progress_targets.append(progress.to(self.device))
|
||||||
@@ -623,12 +619,13 @@ class ReWiNDRewardModel(PreTrainedPolicy):
|
|||||||
# For misaligned pairs, we don't need correct progress targets (will be set to 0)
|
# For misaligned pairs, we don't need correct progress targets (will be set to 0)
|
||||||
misaligned_videos_sampled = []
|
misaligned_videos_sampled = []
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
|
# For misaligned videos, use video length as remaining_length
|
||||||
|
video_len = len(misaligned_videos[i])
|
||||||
sampled, _ = sample_video_feature(
|
sampled, _ = sample_video_feature(
|
||||||
misaligned_videos[i],
|
misaligned_videos[i],
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
random_sample=True,
|
random_sample=True,
|
||||||
absolute_indices=None,
|
remaining_length=video_len # Use video length for misaligned pairs
|
||||||
episode_length=None
|
|
||||||
)
|
)
|
||||||
misaligned_videos_sampled.append(sampled.to(self.device))
|
misaligned_videos_sampled.append(sampled.to(self.device))
|
||||||
misaligned_videos_sampled = torch.stack(misaligned_videos_sampled)
|
misaligned_videos_sampled = torch.stack(misaligned_videos_sampled)
|
||||||
|
|||||||
Reference in New Issue
Block a user