more changes

This commit is contained in:
Jade Choghari
2026-01-27 11:14:22 +00:00
parent 4c694e20c7
commit 2bf6359d24
3 changed files with 15 additions and 6 deletions
+11 -1
View File
@@ -1070,7 +1070,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
if len(self.meta.video_keys) > 0:
current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
video_frames = self._query_videos(query_timestamps, ep_idx)
try:
video_frames = self._query_videos(query_timestamps, ep_idx)
except Exception as e:
print("\n" + "=" * 120)
print("[VIDEO DECODE FAILURE]")
print(f"item={item}")
print(f"query_indices={query_indices}")
print(f"query_timestamps={query_timestamps}")
print(f"ep_idx={ep_idx}")
print("=" * 120 + "\n")
raise
item = {**video_frames, **item}
if self.image_transforms is not None:
@@ -61,8 +61,6 @@ class PI05FullConfig(PreTrainedConfig):
# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0
tokenizer_max_length: int = 200 # see openpi `__post_init__`
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
@@ -104,7 +102,7 @@ class PI05FullConfig(PreTrainedConfig):
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
tokenizer_max_length: int = 200 # see openpi `__post_init__`
tokenizer_max_length: int = 48 # see openpi `__post_init__`
def __post_init__(self):
super().__post_init__()
@@ -375,8 +375,9 @@ def compute_layer_complete(
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
after_first_residual = out_emb.clone()
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
# Store reference instead of clone - we need original for second residual
after_first_residual = out_emb
out_emb, gate = layer.post_attention_layernorm(out_emb.clone(), cond=adarms_cond[i])
# convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)