mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 06:29:47 +00:00
test(async): fix feature manipulation (#1957)
* test(async): fix feature manipulation * chore(processor): remove unused functions
This commit is contained in:
@@ -694,7 +694,7 @@ def build_dataset_frame(
|
||||
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
frame[key] = values[key]
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
@@ -23,8 +23,6 @@ from typing import Any
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.constants import OBS_IMAGES
|
||||
|
||||
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
||||
|
||||
|
||||
@@ -154,41 +152,6 @@ def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | An
|
||||
return x
|
||||
|
||||
|
||||
def _is_image(arr: Any) -> bool:
|
||||
"""
|
||||
Check if a given array is likely an image (uint8, 3D).
|
||||
|
||||
Args:
|
||||
arr: The array to check.
|
||||
|
||||
Returns:
|
||||
True if the array matches the image criteria, False otherwise.
|
||||
"""
|
||||
return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3
|
||||
|
||||
|
||||
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""
|
||||
Separate an observation dictionary into state and image components.
|
||||
|
||||
Args:
|
||||
obs: The observation dictionary.
|
||||
|
||||
Returns:
|
||||
A tuple containing two dictionaries: one for state and one for images.
|
||||
"""
|
||||
state, images = {}, {}
|
||||
for k, v in obs.items():
|
||||
if "image" in k.lower() or _is_image(v):
|
||||
images[k] = v
|
||||
else:
|
||||
state[k] = v
|
||||
return state, images
|
||||
|
||||
|
||||
# Private Helper Functions (Common Logic)
|
||||
|
||||
|
||||
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Extract complementary data from a batch dictionary.
|
||||
@@ -209,9 +172,6 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key}
|
||||
|
||||
|
||||
# Core Conversion Functions
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation: dict[str, Any] | None = None,
|
||||
action: PolicyAction | RobotAction | None = None,
|
||||
@@ -279,11 +239,8 @@ def observation_to_transition(observation: RobotObservation) -> EnvTransition:
|
||||
Returns:
|
||||
An `EnvTransition` containing the formatted observation.
|
||||
"""
|
||||
state, images = _split_obs_to_state_and_images(observation)
|
||||
|
||||
image_observations = {f"{OBS_IMAGES}.{cam}": img for cam, img in images.items()}
|
||||
|
||||
return create_transition(observation={**state, **image_observations})
|
||||
return create_transition(observation=observation)
|
||||
|
||||
|
||||
def transition_to_robot_action(transition: EnvTransition) -> RobotAction:
|
||||
|
||||
@@ -17,7 +17,6 @@ import pickle
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -298,7 +297,6 @@ def test_resize_robot_observation_image():
|
||||
assert resized.max() <= 255
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
|
||||
def test_prepare_raw_observation():
|
||||
"""Test the preparation of raw robot observation to lerobot format."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
@@ -329,7 +327,6 @@ def test_prepare_raw_observation():
|
||||
assert isinstance(phone_img, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
|
||||
def test_raw_observation_to_observation_basic():
|
||||
"""Test the main raw_observation_to_observation function."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
@@ -369,7 +366,6 @@ def test_raw_observation_to_observation_basic():
|
||||
assert phone_img.min() >= 0.0 and phone_img.max() <= 1.0
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
|
||||
def test_raw_observation_to_observation_with_non_tensor_data():
|
||||
"""Test that non-tensor data (like task strings) is preserved."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
@@ -387,7 +383,6 @@ def test_raw_observation_to_observation_with_non_tensor_data():
|
||||
assert isinstance(observation["task"], str)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
|
||||
@torch.no_grad()
|
||||
def test_raw_observation_to_observation_device_handling():
|
||||
"""Test that tensors are properly moved to the specified device."""
|
||||
@@ -405,7 +400,6 @@ def test_raw_observation_to_observation_device_handling():
|
||||
assert value.device.type == device, f"Tensor {key} not on {device}"
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
|
||||
def test_raw_observation_to_observation_deterministic():
|
||||
"""Test that the function produces consistent results for the same input."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
@@ -427,7 +421,6 @@ def test_raw_observation_to_observation_deterministic():
|
||||
assert obs1[key] == obs2[key]
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
|
||||
def test_image_processing_pipeline_preserves_content():
|
||||
"""Test that the image processing pipeline preserves recognizable patterns."""
|
||||
# Create an image with a specific pattern
|
||||
|
||||
Reference in New Issue
Block a user