mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-04 16:47:14 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e36b0368d4 | |||
| 67b18d87b2 | |||
| 98052e5f6e | |||
| f59260f4aa | |||
| fc262fbc06 |
@@ -519,6 +519,13 @@ def compute_episode_stats(
|
||||
if features[key]["dtype"] in {"string", "language"}:
|
||||
continue
|
||||
|
||||
# Features with zero-width shapes are skipped (no data to compute stats on)
|
||||
if any(d == 0 for d in features[key].get("shape", ())):
|
||||
logging.debug(
|
||||
f"Skipping statistics computation for feature '{key}' with a zero-width shape {features[key]['shape']}."
|
||||
)
|
||||
continue
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data)
|
||||
axes_to_reduce = (0, 2, 3)
|
||||
|
||||
@@ -67,9 +67,9 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
elif ft["shape"] == (1,):
|
||||
hf_features[key] = datasets.Value(dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 1:
|
||||
hf_features[key] = datasets.Sequence(
|
||||
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
|
||||
)
|
||||
# pyarrow rejects fixed-size lists of length 0, so use a variable length list instead
|
||||
length = ft["shape"][0] if ft["shape"][0] > 0 else -1
|
||||
hf_features[key] = datasets.Sequence(length=length, feature=datasets.Value(dtype=ft["dtype"]))
|
||||
elif len(ft["shape"]) == 2:
|
||||
hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 3:
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
@@ -687,6 +688,28 @@ def test_compute_episode_stats_string_features_skipped():
|
||||
assert "q01" in stats["action"]
|
||||
|
||||
|
||||
def test_compute_episode_stats_zero_width_features_skipped(caplog):
|
||||
"""Test that features with a zero-width dim (e.g. shape=(0,)) are skipped with a debug log."""
|
||||
episode_data = {
|
||||
"empty": np.zeros((100, 0), dtype=np.float32), # Zero-width feature
|
||||
"action": np.random.normal(0, 1, (100, 5)),
|
||||
}
|
||||
features = {
|
||||
"empty": {"dtype": "float32", "shape": (0,)},
|
||||
"action": {"dtype": "float32", "shape": (5,)},
|
||||
}
|
||||
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
stats = compute_episode_stats(episode_data, features)
|
||||
|
||||
# Zero-width features should be skipped with a debug log, others computed as usual
|
||||
assert "empty" not in stats
|
||||
assert "empty" in caplog.text
|
||||
assert "action" in stats
|
||||
assert "q01" in stats["action"]
|
||||
assert stats["action"]["mean"].shape == (5,)
|
||||
|
||||
|
||||
def test_aggregate_feature_stats_with_quantiles():
|
||||
"""Test aggregating feature stats that include quantiles."""
|
||||
stats_ft_list = [
|
||||
|
||||
@@ -1804,3 +1804,11 @@ def test_episode_filter_unknown_key_raises(tmp_path, lerobot_dataset_factory):
|
||||
root=dataset.root,
|
||||
episode_filter=lambda ep: ep["not_a_real_field"] > 0,
|
||||
)
|
||||
|
||||
|
||||
def test_get_hf_features_zero_width_feature_does_not_raise_on_from_dict():
|
||||
import datasets
|
||||
|
||||
features = {"empty": {"dtype": "float32", "shape": (0,), "names": ["empty"]}}
|
||||
hf_features = get_hf_features_from_features(features)
|
||||
datasets.Dataset.from_dict({"empty": [[], []]}, features=hf_features)
|
||||
|
||||
Reference in New Issue
Block a user