mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-27 21:27:21 +00:00
fix
This commit is contained in:
@@ -421,10 +421,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
target = augmented_target[:, idx]
|
||||
else:
|
||||
# Calculate true episode progress using episode_index and frame_index from batch
|
||||
if "episode_index" in batch and "frame_index" in batch and hasattr(self, "episode_data_index"):
|
||||
# Get episode indices and frame indices from batch
|
||||
episode_indices = batch["episode_index"] # Shape: (B,)
|
||||
frame_indices = batch["frame_index"] # Shape: (B,)
|
||||
episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch)
|
||||
if episode_indices is not None and frame_indices is not None and self.episode_data_index is not None:
|
||||
|
||||
# Calculate progress for the current frame in each sample
|
||||
progress_values = []
|
||||
@@ -485,7 +483,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
target = target[:, idx]
|
||||
else:
|
||||
raise ValueError(
|
||||
"No episode information found in batch. Please ensure 'episode_index' and 'frame_index' keys are present."
|
||||
"No episode information found to build full-episode progress. "
|
||||
"Expected 'episode_index' and 'frame_index' in batch and a valid 'episode_data_index' on the policy. "
|
||||
"Please pass RLearNPolicy(episode_data_index=...) built from episodes.jsonl (per-episode lengths), "
|
||||
"and ensure the dataset exposes 'episode_index' and 'frame_index' (shape (B,) or (B,1))."
|
||||
)
|
||||
|
||||
# During inference, we might not want to compute loss
|
||||
@@ -564,6 +565,42 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
return total_loss, loss_dict
|
||||
|
||||
def _extract_episode_and_frame_indices(self, batch: dict[str, Tensor]) -> tuple[Tensor | None, Tensor | None]:
|
||||
"""Try to extract (episode_index, frame_index) tensors from batch or complementary data.
|
||||
|
||||
Accepts shapes (B,) or (B,1) and returns 1D long tensors on the model device.
|
||||
"""
|
||||
device = next(self.parameters()).device
|
||||
|
||||
ep = batch.get("episode_index")
|
||||
fr = batch.get("frame_index")
|
||||
|
||||
# Try complementary_data
|
||||
if (ep is None or fr is None) and isinstance(batch.get("complementary_data"), dict):
|
||||
comp = batch["complementary_data"]
|
||||
ep = comp.get("episode_index", ep)
|
||||
fr = comp.get("frame_index", fr)
|
||||
|
||||
if ep is None or fr is None:
|
||||
return None, None
|
||||
|
||||
# Convert to 1D long tensors on device
|
||||
if torch.is_tensor(ep):
|
||||
if ep.dim() == 2 and ep.shape[1] == 1:
|
||||
ep = ep.squeeze(1)
|
||||
ep = ep.to(device=device, dtype=torch.long)
|
||||
else:
|
||||
ep = torch.as_tensor(ep, device=device, dtype=torch.long)
|
||||
|
||||
if torch.is_tensor(fr):
|
||||
if fr.dim() == 2 and fr.shape[1] == 1:
|
||||
fr = fr.squeeze(1)
|
||||
fr = fr.to(device=device, dtype=torch.long)
|
||||
else:
|
||||
fr = torch.as_tensor(fr, device=device, dtype=torch.long)
|
||||
|
||||
return ep, fr
|
||||
|
||||
|
||||
# Helper functions for ReWiND architecture
|
||||
|
||||
|
||||
Reference in New Issue
Block a user