test(async): fix feature manipulation (#1957)

* test(async): fix feature manipulation

* chore(processor): remove unused functions
This commit is contained in:
Steven Palma
2025-09-16 15:49:32 +02:00
committed by GitHub
parent 27a229ea64
commit 772da63a8e
3 changed files with 2 additions and 52 deletions
+1 -1
View File
@@ -694,7 +694,7 @@ def build_dataset_frame(
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1: elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
elif ft["dtype"] in ["image", "video"]: elif ft["dtype"] in ["image", "video"]:
frame[key] = values[key] frame[key] = values[key.removeprefix(f"{prefix}.images.")]
return frame return frame
+1 -44
View File
@@ -23,8 +23,6 @@ from typing import Any
import numpy as np import numpy as np
import torch import torch
from lerobot.constants import OBS_IMAGES
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
@@ -154,41 +152,6 @@ def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | An
return x return x
def _is_image(arr: Any) -> bool:
"""
Check if a given array is likely an image (uint8, 3D).
Args:
arr: The array to check.
Returns:
True if the array matches the image criteria, False otherwise.
"""
return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
"""
Separate an observation dictionary into state and image components.
Args:
obs: The observation dictionary.
Returns:
A tuple containing two dictionaries: one for state and one for images.
"""
state, images = {}, {}
for k, v in obs.items():
if "image" in k.lower() or _is_image(v):
images[k] = v
else:
state[k] = v
return state, images
# Private Helper Functions (Common Logic)
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
""" """
Extract complementary data from a batch dictionary. Extract complementary data from a batch dictionary.
@@ -209,9 +172,6 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
return {**pad_keys, **task_key, **index_key, **task_index_key} return {**pad_keys, **task_key, **index_key, **task_index_key}
# Core Conversion Functions
def create_transition( def create_transition(
observation: dict[str, Any] | None = None, observation: dict[str, Any] | None = None,
action: PolicyAction | RobotAction | None = None, action: PolicyAction | RobotAction | None = None,
@@ -279,11 +239,8 @@ def observation_to_transition(observation: RobotObservation) -> EnvTransition:
Returns: Returns:
An `EnvTransition` containing the formatted observation. An `EnvTransition` containing the formatted observation.
""" """
state, images = _split_obs_to_state_and_images(observation)
image_observations = {f"{OBS_IMAGES}.{cam}": img for cam, img in images.items()} return create_transition(observation=observation)
return create_transition(observation={**state, **image_observations})
def transition_to_robot_action(transition: EnvTransition) -> RobotAction: def transition_to_robot_action(transition: EnvTransition) -> RobotAction:
-7
View File
@@ -17,7 +17,6 @@ import pickle
import time import time
import numpy as np import numpy as np
import pytest
import torch import torch
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
@@ -298,7 +297,6 @@ def test_resize_robot_observation_image():
assert resized.max() <= 255 assert resized.max() <= 255
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
def test_prepare_raw_observation(): def test_prepare_raw_observation():
"""Test the preparation of raw robot observation to lerobot format.""" """Test the preparation of raw robot observation to lerobot format."""
robot_obs = _create_mock_robot_observation() robot_obs = _create_mock_robot_observation()
@@ -329,7 +327,6 @@ def test_prepare_raw_observation():
assert isinstance(phone_img, torch.Tensor) assert isinstance(phone_img, torch.Tensor)
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
def test_raw_observation_to_observation_basic(): def test_raw_observation_to_observation_basic():
"""Test the main raw_observation_to_observation function.""" """Test the main raw_observation_to_observation function."""
robot_obs = _create_mock_robot_observation() robot_obs = _create_mock_robot_observation()
@@ -369,7 +366,6 @@ def test_raw_observation_to_observation_basic():
assert phone_img.min() >= 0.0 and phone_img.max() <= 1.0 assert phone_img.min() >= 0.0 and phone_img.max() <= 1.0
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
def test_raw_observation_to_observation_with_non_tensor_data(): def test_raw_observation_to_observation_with_non_tensor_data():
"""Test that non-tensor data (like task strings) is preserved.""" """Test that non-tensor data (like task strings) is preserved."""
robot_obs = _create_mock_robot_observation() robot_obs = _create_mock_robot_observation()
@@ -387,7 +383,6 @@ def test_raw_observation_to_observation_with_non_tensor_data():
assert isinstance(observation["task"], str) assert isinstance(observation["task"], str)
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
@torch.no_grad() @torch.no_grad()
def test_raw_observation_to_observation_device_handling(): def test_raw_observation_to_observation_device_handling():
"""Test that tensors are properly moved to the specified device.""" """Test that tensors are properly moved to the specified device."""
@@ -405,7 +400,6 @@ def test_raw_observation_to_observation_device_handling():
assert value.device.type == device, f"Tensor {key} not on {device}" assert value.device.type == device, f"Tensor {key} not on {device}"
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
def test_raw_observation_to_observation_deterministic(): def test_raw_observation_to_observation_deterministic():
"""Test that the function produces consistent results for the same input.""" """Test that the function produces consistent results for the same input."""
robot_obs = _create_mock_robot_observation() robot_obs = _create_mock_robot_observation()
@@ -427,7 +421,6 @@ def test_raw_observation_to_observation_deterministic():
assert obs1[key] == obs2[key] assert obs1[key] == obs2[key]
@pytest.mark.skip("TODO(Steven): Skipping test - Check new feature manipulation")
def test_image_processing_pipeline_preserves_content(): def test_image_processing_pipeline_preserves_content():
"""Test that the image processing pipeline preserves recognizable patterns.""" """Test that the image processing pipeline preserves recognizable patterns."""
# Create an image with a specific pattern # Create an image with a specific pattern