mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
chore(processor): rename merge_features -> combine_feature_dicts (#1856)
This commit is contained in:
@@ -17,7 +17,7 @@
|
|||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||||
from lerobot.datasets.utils import merge_features
|
from lerobot.datasets.utils import combine_feature_dicts
|
||||||
from lerobot.model.kinematics import RobotKinematics
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||||
from lerobot.policies.factory import make_pre_post_processors
|
from lerobot.policies.factory import make_pre_post_processors
|
||||||
@@ -103,7 +103,7 @@ obs_ee = aggregate_pipeline_dataset_features(
|
|||||||
patterns=["observation.state.ee"],
|
patterns=["observation.state.ee"],
|
||||||
) # Get all ee observation features
|
) # Get all ee observation features
|
||||||
|
|
||||||
dataset_features = merge_features(obs_ee, action_ee_and_gripper)
|
dataset_features = combine_feature_dicts(obs_ee, action_ee_and_gripper)
|
||||||
|
|
||||||
print("All dataset features: ", dataset_features)
|
print("All dataset features: ", dataset_features)
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||||
from lerobot.datasets.utils import merge_features
|
from lerobot.datasets.utils import combine_feature_dicts
|
||||||
from lerobot.model.kinematics import RobotKinematics
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
from lerobot.processor import DataProcessorPipeline
|
from lerobot.processor import DataProcessorPipeline
|
||||||
from lerobot.processor.converters import (
|
from lerobot.processor.converters import (
|
||||||
@@ -142,7 +142,7 @@ observation_ee = aggregate_pipeline_dataset_features(
|
|||||||
patterns=["observation.state.ee"],
|
patterns=["observation.state.ee"],
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_features = merge_features(action_ee, gripper, observation_ee)
|
dataset_features = combine_feature_dicts(action_ee, gripper, observation_ee)
|
||||||
|
|
||||||
print("All dataset features: ", dataset_features)
|
print("All dataset features: ", dataset_features)
|
||||||
|
|
||||||
|
|||||||
@@ -470,7 +470,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
|||||||
return policy_features
|
return policy_features
|
||||||
|
|
||||||
|
|
||||||
def merge_features(*dicts: dict) -> dict:
|
def combine_feature_dicts(*dicts: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
Merge LeRobot grouped feature dicts.
|
Merge LeRobot grouped feature dicts.
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from datasets import Dataset
|
|||||||
from huggingface_hub import DatasetCard
|
from huggingface_hub import DatasetCard
|
||||||
|
|
||||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||||
from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch, merge_features
|
from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch
|
||||||
|
|
||||||
|
|
||||||
def test_default_parameters():
|
def test_default_parameters():
|
||||||
@@ -72,7 +72,7 @@ def test_merge_simple_vectors():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
out = merge_features(g1, g2)
|
out = combine_feature_dicts(g1, g2)
|
||||||
|
|
||||||
assert "action" in out
|
assert "action" in out
|
||||||
assert out["action"]["dtype"] == "float32"
|
assert out["action"]["dtype"] == "float32"
|
||||||
@@ -87,7 +87,7 @@ def test_merge_multiple_groups_order_and_dedup():
|
|||||||
g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}}
|
g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}}
|
||||||
g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}}
|
g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}}
|
||||||
|
|
||||||
out = merge_features(g1, g2, g3)
|
out = combine_feature_dicts(g1, g2, g3)
|
||||||
|
|
||||||
assert out["action"]["names"] == ["a", "b", "c", "d"]
|
assert out["action"]["names"] == ["a", "b", "c", "d"]
|
||||||
assert out["action"]["shape"] == (4,)
|
assert out["action"]["shape"] == (4,)
|
||||||
@@ -110,7 +110,7 @@ def test_non_vector_last_wins_for_images():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
out = merge_features(g1, g2)
|
out = combine_feature_dicts(g1, g2)
|
||||||
assert out["observation.images.front"]["shape"] == (3, 720, 1280)
|
assert out["observation.images.front"]["shape"] == (3, 720, 1280)
|
||||||
assert out["observation.images.front"]["dtype"] == "image"
|
assert out["observation.images.front"]["dtype"] == "image"
|
||||||
|
|
||||||
@@ -120,13 +120,13 @@ def test_dtype_mismatch_raises():
|
|||||||
g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}}
|
g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}}
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="dtype mismatch for 'action'"):
|
with pytest.raises(ValueError, match="dtype mismatch for 'action'"):
|
||||||
_ = merge_features(g1, g2)
|
_ = combine_feature_dicts(g1, g2)
|
||||||
|
|
||||||
|
|
||||||
def test_non_dict_passthrough_last_wins():
|
def test_non_dict_passthrough_last_wins():
|
||||||
g1 = {"misc": 123}
|
g1 = {"misc": 123}
|
||||||
g2 = {"misc": 456}
|
g2 = {"misc": 456}
|
||||||
|
|
||||||
out = merge_features(g1, g2)
|
out = combine_feature_dicts(g1, g2)
|
||||||
# For non-dict entries the last one wins
|
# For non-dict entries the last one wins
|
||||||
assert out["misc"] == 456
|
assert out["misc"] == 456
|
||||||
|
|||||||
Reference in New Issue
Block a user