mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
test(async): fix feature manipulation (#1957)
* test(async): fix feature manipulation * chore(processor): remove unused functions
This commit is contained in:
@@ -694,7 +694,7 @@ def build_dataset_frame(
|
|||||||
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
||||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||||
elif ft["dtype"] in ["image", "video"]:
|
elif ft["dtype"] in ["image", "video"]:
|
||||||
frame[key] = values[key]
|
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||||
|
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
|
|||||||
@@ -23,8 +23,6 @@ from typing import Any
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.constants import OBS_IMAGES
|
|
||||||
|
|
||||||
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
||||||
|
|
||||||
|
|
||||||
@@ -154,41 +152,6 @@ def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | An
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def _is_image(arr: Any) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a given array is likely an image (uint8, 3D).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
arr: The array to check.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the array matches the image criteria, False otherwise.
|
|
||||||
"""
|
|
||||||
return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3
|
|
||||||
|
|
||||||
|
|
||||||
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Separate an observation dictionary into state and image components.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
obs: The observation dictionary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple containing two dictionaries: one for state and one for images.
|
|
||||||
"""
|
|
||||||
state, images = {}, {}
|
|
||||||
for k, v in obs.items():
|
|
||||||
if "image" in k.lower() or _is_image(v):
|
|
||||||
images[k] = v
|
|
||||||
else:
|
|
||||||
state[k] = v
|
|
||||||
return state, images
|
|
||||||
|
|
||||||
|
|
||||||
# Private Helper Functions (Common Logic)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Extract complementary data from a batch dictionary.
|
Extract complementary data from a batch dictionary.
|
||||||
@@ -209,9 +172,6 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
|||||||
return {**pad_keys, **task_key, **index_key, **task_index_key}
|
return {**pad_keys, **task_key, **index_key, **task_index_key}
|
||||||
|
|
||||||
|
|
||||||
# Core Conversion Functions
|
|
||||||
|
|
||||||
|
|
||||||
def create_transition(
|
def create_transition(
|
||||||
observation: dict[str, Any] | None = None,
|
observation: dict[str, Any] | None = None,
|
||||||
action: PolicyAction | RobotAction | None = None,
|
action: PolicyAction | RobotAction | None = None,
|
||||||
@@ -279,11 +239,8 @@ def observation_to_transition(observation: RobotObservation) -> EnvTransition:
|
|||||||
Returns:
|
Returns:
|
||||||
An `EnvTransition` containing the formatted observation.
|
An `EnvTransition` containing the formatted observation.
|
||||||
"""
|
"""
|
||||||
state, images = _split_obs_to_state_and_images(observation)
|
|
||||||
|
|
||||||
image_observations = {f"{OBS_IMAGES}.{cam}": img for cam, img in images.items()}
|
return create_transition(observation=observation)
|
||||||
|
|
||||||
return create_transition(observation={**state, **image_observations})
|
|
||||||
|
|
||||||
|
|
||||||
def transition_to_robot_action(transition: EnvTransition) -> RobotAction:
|
def transition_to_robot_action(transition: EnvTransition) -> RobotAction:
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import pickle
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
@@ -298,7 +297,6 @@ def test_resize_robot_observation_image():
|
|||||||
assert resized.max() <= 255
|
assert resized.max() <= 255
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
|
|
||||||
def test_prepare_raw_observation():
|
def test_prepare_raw_observation():
|
||||||
"""Test the preparation of raw robot observation to lerobot format."""
|
"""Test the preparation of raw robot observation to lerobot format."""
|
||||||
robot_obs = _create_mock_robot_observation()
|
robot_obs = _create_mock_robot_observation()
|
||||||
@@ -329,7 +327,6 @@ def test_prepare_raw_observation():
|
|||||||
assert isinstance(phone_img, torch.Tensor)
|
assert isinstance(phone_img, torch.Tensor)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
|
|
||||||
def test_raw_observation_to_observation_basic():
|
def test_raw_observation_to_observation_basic():
|
||||||
"""Test the main raw_observation_to_observation function."""
|
"""Test the main raw_observation_to_observation function."""
|
||||||
robot_obs = _create_mock_robot_observation()
|
robot_obs = _create_mock_robot_observation()
|
||||||
@@ -369,7 +366,6 @@ def test_raw_observation_to_observation_basic():
|
|||||||
assert phone_img.min() >= 0.0 and phone_img.max() <= 1.0
|
assert phone_img.min() >= 0.0 and phone_img.max() <= 1.0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
|
|
||||||
def test_raw_observation_to_observation_with_non_tensor_data():
|
def test_raw_observation_to_observation_with_non_tensor_data():
|
||||||
"""Test that non-tensor data (like task strings) is preserved."""
|
"""Test that non-tensor data (like task strings) is preserved."""
|
||||||
robot_obs = _create_mock_robot_observation()
|
robot_obs = _create_mock_robot_observation()
|
||||||
@@ -387,7 +383,6 @@ def test_raw_observation_to_observation_with_non_tensor_data():
|
|||||||
assert isinstance(observation["task"], str)
|
assert isinstance(observation["task"], str)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def test_raw_observation_to_observation_device_handling():
|
def test_raw_observation_to_observation_device_handling():
|
||||||
"""Test that tensors are properly moved to the specified device."""
|
"""Test that tensors are properly moved to the specified device."""
|
||||||
@@ -405,7 +400,6 @@ def test_raw_observation_to_observation_device_handling():
|
|||||||
assert value.device.type == device, f"Tensor {key} not on {device}"
|
assert value.device.type == device, f"Tensor {key} not on {device}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
|
|
||||||
def test_raw_observation_to_observation_deterministic():
|
def test_raw_observation_to_observation_deterministic():
|
||||||
"""Test that the function produces consistent results for the same input."""
|
"""Test that the function produces consistent results for the same input."""
|
||||||
robot_obs = _create_mock_robot_observation()
|
robot_obs = _create_mock_robot_observation()
|
||||||
@@ -427,7 +421,6 @@ def test_raw_observation_to_observation_deterministic():
|
|||||||
assert obs1[key] == obs2[key]
|
assert obs1[key] == obs2[key]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
|
|
||||||
def test_image_processing_pipeline_preserves_content():
|
def test_image_processing_pipeline_preserves_content():
|
||||||
"""Test that the image processing pipeline preserves recognizable patterns."""
|
"""Test that the image processing pipeline preserves recognizable patterns."""
|
||||||
# Create an image with a specific pattern
|
# Create an image with a specific pattern
|
||||||
|
|||||||
Reference in New Issue
Block a user