mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -183,10 +183,20 @@ def test_pi05_rtc_inference_with_prev_chunk():
|
|||||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create dataset stats
|
# Create dataset stats (PI0.5 uses QUANTILES normalization)
|
||||||
dataset_stats = {
|
dataset_stats = {
|
||||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
"observation.state": {
|
||||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
"mean": torch.zeros(14),
|
||||||
|
"std": torch.ones(14),
|
||||||
|
"q01": -torch.ones(14),
|
||||||
|
"q99": torch.ones(14),
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"mean": torch.zeros(7),
|
||||||
|
"std": torch.ones(7),
|
||||||
|
"q01": -torch.ones(7),
|
||||||
|
"q99": torch.ones(7),
|
||||||
|
},
|
||||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -260,10 +270,20 @@ def test_pi05_rtc_inference_without_prev_chunk():
|
|||||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create dataset stats
|
# Create dataset stats (PI0.5 uses QUANTILES normalization)
|
||||||
dataset_stats = {
|
dataset_stats = {
|
||||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
"observation.state": {
|
||||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
"mean": torch.zeros(14),
|
||||||
|
"std": torch.ones(14),
|
||||||
|
"q01": -torch.ones(14),
|
||||||
|
"q99": torch.ones(14),
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"mean": torch.zeros(7),
|
||||||
|
"std": torch.ones(7),
|
||||||
|
"q01": -torch.ones(7),
|
||||||
|
"q99": torch.ones(7),
|
||||||
|
},
|
||||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -328,10 +348,20 @@ def test_pi05_rtc_validation_rules():
|
|||||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create dataset stats
|
# Create dataset stats (PI0.5 uses QUANTILES normalization)
|
||||||
dataset_stats = {
|
dataset_stats = {
|
||||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
"observation.state": {
|
||||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
"mean": torch.zeros(14),
|
||||||
|
"std": torch.ones(14),
|
||||||
|
"q01": -torch.ones(14),
|
||||||
|
"q99": torch.ones(14),
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"mean": torch.zeros(7),
|
||||||
|
"std": torch.ones(7),
|
||||||
|
"q01": -torch.ones(7),
|
||||||
|
"q99": torch.ones(7),
|
||||||
|
},
|
||||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user