feat(datasets): warn when skipping stats for zero-width features

Per review, log a warning when compute_episode_stats skips a feature with a
zero-width shape, so users know stats were intentionally not computed for it.
This commit is contained in:
Mahbod
2026-06-12 12:14:32 +02:00
committed by CarolinePascal
parent f59260f4aa
commit 98052e5f6e
2 changed files with 10 additions and 7 deletions
+3
View File
@@ -520,6 +520,9 @@ def compute_episode_stats(
continue
if any(d == 0 for d in features[key].get("shape", ())):
logging.warning(
f"Skipping stats for feature '{key}' with a zero-width shape {features[key]['shape']}."
)
continue
if features[key]["dtype"] in ["image", "video"]:
+7 -7
View File
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from unittest.mock import patch
import numpy as np
@@ -687,8 +688,8 @@ def test_compute_episode_stats_string_features_skipped():
assert "q01" in stats["action"]
def test_compute_episode_stats_zero_width_features_skipped():
"""Test that features with a zero-width dim (e.g. shape=(0,)) are skipped."""
def test_compute_episode_stats_zero_width_features_skipped(caplog):
"""Test that features with a zero-width dim (e.g. shape=(0,)) are skipped with a warning."""
episode_data = {
"empty": np.zeros((100, 0), dtype=np.float32), # Zero-width feature
"action": np.random.normal(0, 1, (100, 5)),
@@ -698,13 +699,12 @@ def test_compute_episode_stats_zero_width_features_skipped():
"action": {"dtype": "float32", "shape": (5,)},
}
stats = compute_episode_stats(
episode_data,
features,
)
with caplog.at_level(logging.WARNING):
stats = compute_episode_stats(episode_data, features)
# Zero-width features should be skipped
# Zero-width features should be skipped with a warning, others computed as usual
assert "empty" not in stats
assert "empty" in caplog.text
assert "action" in stats
assert "q01" in stats["action"]
assert stats["action"]["mean"].shape == (5,)