Mirror of https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
refactor(converters): implement unified tensor conversion function (#1841)
- Introduced a `to_tensor` function using `singledispatch` to handle various input types, including scalars, arrays, and dictionaries, converting them to PyTorch tensors.
- Replaced the previous tensor conversion logic in `gym_action_processor`, `normalize_processor`, and `test_converters` with the new `to_tensor` function for improved readability and maintainability.
- Updated tests to cover the new functionality and ensure correct tensor conversion behavior.

Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from copy import deepcopy
|
||||
from functools import singledispatch
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
@@ -29,17 +30,113 @@ from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNC
|
||||
from .pipeline import EnvTransition, TransitionKey
|
||||
|
||||
|
||||
def _to_tensor(x: torch.Tensor | np.ndarray | Sequence[int | float]):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x
|
||||
if isinstance(x, np.ndarray):
|
||||
# Keep images (uint8 HWC) and python objects as-is
|
||||
if x.dtype == np.uint8 or x.dtype == np.object_:
|
||||
return x
|
||||
# Scalars/arrays to float32 tensor
|
||||
return torch.as_tensor(x, dtype=torch.float32)
|
||||
# Anything else to float32 tensor
|
||||
return torch.as_tensor(x, dtype=torch.float32)
|
||||
@singledispatch
|
||||
def to_tensor(
|
||||
value: Any,
|
||||
*,
|
||||
dtype: torch.dtype | None = torch.float32,
|
||||
device: torch.device | str | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert various data types to PyTorch tensors with configurable options.
|
||||
|
||||
This is a unified tensor conversion function using single dispatch to handle
|
||||
different input types appropriately.
|
||||
|
||||
Args:
|
||||
value: Input value to convert (tensor, array, scalar, sequence, etc.)
|
||||
dtype: Target tensor dtype. If None, preserves original dtype.
|
||||
device: Target device for the tensor.
|
||||
|
||||
Returns:
|
||||
PyTorch tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input type is not supported.
|
||||
"""
|
||||
raise TypeError(f"Unsupported type for tensor conversion: {type(value)}")
|
||||
|
||||
|
||||
@to_tensor.register(torch.Tensor)
|
||||
def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle existing PyTorch tensors."""
|
||||
if dtype is not None:
|
||||
value = value.to(dtype=dtype)
|
||||
if device is not None:
|
||||
value = value.to(device=device)
|
||||
return value
|
||||
|
||||
|
||||
@to_tensor.register(np.ndarray)
|
||||
def _(
|
||||
value: np.ndarray,
|
||||
*,
|
||||
dtype=torch.float32,
|
||||
device=None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Handle numpy arrays."""
|
||||
# Check for numpy scalars (0-dimensional arrays) and treat them as scalars
|
||||
if value.ndim == 0:
|
||||
# Numpy scalars should be converted to 0-dimensional tensors
|
||||
scalar_value = value.item()
|
||||
return torch.tensor(scalar_value, dtype=dtype, device=device)
|
||||
|
||||
# Create tensor from numpy array (torch.from_numpy handles contiguity automatically)
|
||||
tensor = torch.from_numpy(value)
|
||||
|
||||
# Apply dtype conversion if specified
|
||||
if dtype is not None:
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
if device is not None:
|
||||
tensor = tensor.to(device=device)
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
@to_tensor.register(int)
|
||||
@to_tensor.register(float)
|
||||
@to_tensor.register(np.integer)
|
||||
@to_tensor.register(np.floating)
|
||||
def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle scalar values including numpy scalars."""
|
||||
return torch.tensor(value, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@to_tensor.register(list)
|
||||
@to_tensor.register(tuple)
|
||||
def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle sequences (lists, tuples)."""
|
||||
return torch.tensor(value, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@to_tensor.register(dict)
|
||||
def _(value: dict, *, device=None, **kwargs) -> dict:
|
||||
"""Handle dictionaries by recursively converting values to tensors."""
|
||||
if not value:
|
||||
return {}
|
||||
|
||||
result = {}
|
||||
for key, sub_value in value.items():
|
||||
if sub_value is None:
|
||||
continue
|
||||
|
||||
if isinstance(sub_value, dict):
|
||||
# Recursively process nested dictionaries
|
||||
result[key] = to_tensor(
|
||||
sub_value,
|
||||
device=device,
|
||||
**kwargs,
|
||||
)
|
||||
continue
|
||||
|
||||
# Convert individual values to tensors
|
||||
result[key] = to_tensor(
|
||||
sub_value,
|
||||
device=device,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _from_tensor(x: Any):
|
||||
@@ -88,7 +185,7 @@ def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition:
|
||||
continue
|
||||
|
||||
arr = np.array(v) if np.isscalar(v) else v
|
||||
act_dict[f"{ACTION}.{k}"] = _to_tensor(arr)
|
||||
act_dict[f"{ACTION}.{k}"] = to_tensor(arr)
|
||||
|
||||
return make_obs_act_transition(act=act_dict)
|
||||
|
||||
@@ -103,7 +200,7 @@ def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransitio
|
||||
obs_dict: dict[str, Any] = {}
|
||||
for k, v in state.items():
|
||||
arr = np.array(v) if np.isscalar(v) else v
|
||||
obs_dict[f"{OBS_STATE}.{k}"] = _to_tensor(arr)
|
||||
obs_dict[f"{OBS_STATE}.{k}"] = to_tensor(arr)
|
||||
|
||||
for cam, img in images.items():
|
||||
obs_dict[f"{OBS_IMAGES}.{cam}"] = img
|
||||
|
||||
@@ -16,6 +16,7 @@ from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
|
||||
|
||||
|
||||
@@ -59,5 +60,5 @@ class Numpy2TorchActionProcessor(ActionProcessor):
|
||||
f"Expected np.ndarray or None, got {type(action).__name__}. "
|
||||
"Use appropriate processor for non-tensor actions."
|
||||
)
|
||||
torch_action = torch.from_numpy(action)
|
||||
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
|
||||
return torch_action
|
||||
|
||||
@@ -4,12 +4,12 @@ from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.processor.pipeline import (
|
||||
EnvTransition,
|
||||
ProcessorStep,
|
||||
@@ -19,37 +19,6 @@ from lerobot.processor.pipeline import (
|
||||
)
|
||||
|
||||
|
||||
def _to_tensor(value: Any, device: torch.device | None = None) -> Tensor:
|
||||
"""Convert common python/numpy/torch types to a torch.float32 tensor.
|
||||
|
||||
Always returns float32; preserves device if provided.
|
||||
"""
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.to(dtype=torch.float32, device=device)
|
||||
if isinstance(value, np.ndarray):
|
||||
# ensure contiguous, cast to float32 then convert
|
||||
return torch.from_numpy(np.ascontiguousarray(value.astype(np.float32))).to(device=device)
|
||||
if isinstance(value, (int, float)):
|
||||
return torch.tensor(value, dtype=torch.float32, device=device)
|
||||
if isinstance(value, (list, tuple)):
|
||||
return torch.tensor(value, dtype=torch.float32, device=device)
|
||||
raise TypeError(f"Unsupported type for stats value: {type(value)}")
|
||||
|
||||
|
||||
def _convert_stats_to_tensors(
    stats: dict[str, dict[str, Any]], device: torch.device | None = None
) -> dict[str, dict[str, Tensor]]:
    """Convert numeric stats values to torch tensors, preserving keys."""
    # Entries whose sub-dict is None are skipped; every remaining stat value
    # goes through _to_tensor. `stats or {}` tolerates a None stats argument.
    return {
        group: {name: _to_tensor(raw, device=device) for name, raw in sub.items()}
        for group, sub in (stats or {}).items()
        if sub is not None
    }
|
||||
|
||||
|
||||
@dataclass
|
||||
class _NormalizationMixin:
|
||||
"""
|
||||
@@ -91,12 +60,12 @@ class _NormalizationMixin:
|
||||
|
||||
# Convert stats to tensors and move to the target device once during initialization.
|
||||
self.stats = self.stats or {}
|
||||
self._tensor_stats = _convert_stats_to_tensors(self.stats, device=self.device)
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device)
|
||||
|
||||
def to(self, device: torch.device | str) -> _NormalizationMixin:
|
||||
"""Moves the processor's normalization stats to the specified device and returns self."""
|
||||
self.device = device
|
||||
self._tensor_stats = _convert_stats_to_tensors(self.stats, device=self.device)
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device)
|
||||
return self
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
@@ -165,7 +134,7 @@ class _NormalizationMixin:
|
||||
# Move stats to input device if needed
|
||||
stats_device = next(iter(stats.values())).device
|
||||
if stats_device != input_device:
|
||||
stats = _convert_stats_to_tensors({key: self._tensor_stats[key]}, device=input_device)[key]
|
||||
stats = to_tensor({key: self._tensor_stats[key]}, device=input_device)[key]
|
||||
|
||||
if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
@@ -295,5 +264,5 @@ def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, An
|
||||
if isinstance(step, _NormalizationMixin):
|
||||
step.stats = stats
|
||||
# Re-initialize tensor_stats on the correct device.
|
||||
step._tensor_stats = _convert_stats_to_tensors(stats, device=step.device)
|
||||
step._tensor_stats = to_tensor(stats, device=step.device)
|
||||
return rp
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch
|
||||
from lerobot.processor.converters import (
|
||||
to_dataset_frame,
|
||||
to_output_robot_action,
|
||||
to_tensor,
|
||||
to_transition_robot_observation,
|
||||
to_transition_teleop_action,
|
||||
)
|
||||
@@ -12,12 +13,12 @@ from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def test_to_transition_teleop_action_prefix_and_tensor_conversion():
|
||||
# Scalars, arrays, and "image-like" uint8 arrays are supported
|
||||
# Scalars, arrays, and uint8 arrays are all converted to tensors
|
||||
img = np.zeros((8, 12, 3), dtype=np.uint8)
|
||||
act = {
|
||||
"ee.x": 0.5, # scalar to torch tensor
|
||||
"delta": np.array([1.0, 2.0]), # ndarray to torch tensor
|
||||
"raw_img": img, # uint8 HWC to passthrough ndarray
|
||||
"raw_img": img, # uint8 HWC to torch tensor
|
||||
}
|
||||
|
||||
tr = to_transition_teleop_action(act)
|
||||
@@ -29,7 +30,7 @@ def test_to_transition_teleop_action_prefix_and_tensor_conversion():
|
||||
assert "action.delta" in tr[TransitionKey.ACTION]
|
||||
assert "action.raw_img" in tr[TransitionKey.ACTION]
|
||||
|
||||
# Types: scalars/arrays -> torch tensor; images to np.ndarray
|
||||
# Types: all values -> torch tensor
|
||||
assert isinstance(tr[TransitionKey.ACTION]["action.ee.x"], torch.Tensor)
|
||||
assert tr[TransitionKey.ACTION]["action.ee.x"].item() == pytest.approx(0.5)
|
||||
|
||||
@@ -37,8 +38,8 @@ def test_to_transition_teleop_action_prefix_and_tensor_conversion():
|
||||
assert tr[TransitionKey.ACTION]["action.delta"].shape == (2,)
|
||||
assert torch.allclose(tr[TransitionKey.ACTION]["action.delta"], torch.tensor([1.0, 2.0]))
|
||||
|
||||
assert isinstance(tr[TransitionKey.ACTION]["action.raw_img"], np.ndarray)
|
||||
assert tr[TransitionKey.ACTION]["action.raw_img"].dtype == np.uint8
|
||||
assert isinstance(tr[TransitionKey.ACTION]["action.raw_img"], torch.Tensor)
|
||||
assert tr[TransitionKey.ACTION]["action.raw_img"].dtype == torch.float32 # converted from uint8
|
||||
assert tr[TransitionKey.ACTION]["action.raw_img"].shape == (8, 12, 3)
|
||||
|
||||
# Observation is created as empty dict by make_transition
|
||||
@@ -194,3 +195,185 @@ def test_to_dataset_frame_merge_and_pack_vectors_and_metadata():
|
||||
# Complementary data
|
||||
assert batch["frame_is_pad"] is True
|
||||
assert batch["task"] == "Pick cube"
|
||||
|
||||
|
||||
# Tests for the unified to_tensor function
|
||||
def test_to_tensor_numpy_arrays():
    """Test to_tensor with various numpy arrays."""
    # Every numpy dtype below should come back as a float32 torch.Tensor,
    # including uint8 arrays that were previously passed through untouched.
    expectations = [
        (np.array([1.0, 2.0, 3.0]), [1.0, 2.0, 3.0]),
        (np.array([1, 2, 3], dtype=np.int64), [1.0, 2.0, 3.0]),
        (np.array([100, 150, 200], dtype=np.uint8), [100.0, 150.0, 200.0]),
    ]
    for source, expected in expectations:
        converted = to_tensor(source)
        assert isinstance(converted, torch.Tensor)
        assert converted.dtype == torch.float32
        assert torch.allclose(converted, torch.tensor(expected))
|
||||
|
||||
|
||||
def test_to_tensor_numpy_scalars():
    """Test to_tensor with numpy scalar values."""
    # numpy float32 and int32 scalars both become 0-d float32 tensors.
    for scalar, expected in ((np.float32(3.14), 3.14), (np.int32(42), 42.0)):
        converted = to_tensor(scalar)
        assert isinstance(converted, torch.Tensor)
        assert converted.ndim == 0  # scalars map to 0-dimensional tensors
        assert converted.dtype == torch.float32
        assert converted.item() == pytest.approx(expected)
|
||||
|
||||
|
||||
def test_to_tensor_python_scalars():
    """Test to_tensor with plain Python numbers."""
    # Both int and float inputs become float32 tensors with the same value.
    for value in (42, 3.14):
        converted = to_tensor(value)
        assert isinstance(converted, torch.Tensor)
        assert converted.dtype == torch.float32
        assert converted.item() == pytest.approx(float(value))
|
||||
|
||||
|
||||
def test_to_tensor_sequences():
    """Test to_tensor with lists and tuples."""
    cases = (
        ([1, 2, 3], [1.0, 2.0, 3.0]),
        ((4.5, 5.5, 6.5), [4.5, 5.5, 6.5]),
    )
    for sequence, expected in cases:
        converted = to_tensor(sequence)
        assert isinstance(converted, torch.Tensor)
        assert converted.dtype == torch.float32
        assert torch.allclose(converted, torch.tensor(expected))
|
||||
|
||||
|
||||
def test_to_tensor_existing_tensors():
    """Test to_tensor with tensors that already exist."""
    # A float32 tensor passes through with its values intact.
    float_input = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
    converted = to_tensor(float_input)
    assert isinstance(converted, torch.Tensor)
    assert converted.dtype == torch.float32
    assert torch.allclose(converted, float_input)

    # An integer tensor is cast to float32.
    converted = to_tensor(torch.tensor([1, 2, 3], dtype=torch.int64))
    assert isinstance(converted, torch.Tensor)
    assert converted.dtype == torch.float32
    assert torch.allclose(converted, torch.tensor([1.0, 2.0, 3.0]))
|
||||
|
||||
|
||||
def test_to_tensor_dictionaries():
    """Test to_tensor with flat and nested dictionaries."""
    # Flat dictionary: every value (list, ndarray, scalar) becomes a tensor.
    flat = to_tensor({"mean": [0.1, 0.2], "std": np.array([1.0, 2.0]), "count": 42})
    assert isinstance(flat, dict)
    for name in ("mean", "std", "count"):
        assert isinstance(flat[name], torch.Tensor)
    assert torch.allclose(flat["mean"], torch.tensor([0.1, 0.2]))
    assert torch.allclose(flat["std"], torch.tensor([1.0, 2.0]))
    assert flat["count"].item() == pytest.approx(42.0)

    # Nested dictionary: structure is preserved, leaves become tensors.
    nested = to_tensor(
        {
            "action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]},
            "observation": {"mean": np.array([0.5, 0.6]), "count": 10},
        }
    )
    assert isinstance(nested, dict)
    assert isinstance(nested["action"], dict)
    assert isinstance(nested["observation"], dict)
    assert isinstance(nested["action"]["mean"], torch.Tensor)
    assert isinstance(nested["observation"]["mean"], torch.Tensor)
    assert torch.allclose(nested["action"]["mean"], torch.tensor([0.1, 0.2]))
    assert torch.allclose(nested["observation"]["mean"], torch.tensor([0.5, 0.6]))
|
||||
|
||||
|
||||
def test_to_tensor_none_filtering():
    """Test that None values are dropped during dict conversion."""
    converted = to_tensor(
        {"valid": [1, 2, 3], "none_value": None, "nested": {"valid": [4, 5], "also_none": None}}
    )
    # None-valued keys disappear at every nesting level...
    assert "none_value" not in converted
    assert "also_none" not in converted["nested"]
    # ...while real values survive and are converted.
    assert "valid" in converted
    assert "valid" in converted["nested"]
    assert torch.allclose(converted["valid"], torch.tensor([1.0, 2.0, 3.0]))
|
||||
|
||||
|
||||
def test_to_tensor_dtype_parameter():
    """Test the dtype keyword of to_tensor."""
    source = np.array([1, 2, 3])

    # Default and explicit float32 both yield float32.
    assert to_tensor(source).dtype == torch.float32
    assert to_tensor(source, dtype=torch.float32).dtype == torch.float32

    # An explicit float64 request is honored.
    assert to_tensor(source, dtype=torch.float64).dtype == torch.float64

    # dtype=None preserves the input's original dtype.
    float64_source = np.array([1.0, 2.0, 3.0], dtype=np.float64)
    assert to_tensor(float64_source, dtype=None).dtype == torch.float64
|
||||
|
||||
|
||||
def test_to_tensor_device_parameter():
    """Test the device keyword of to_tensor."""
    source = np.array([1.0, 2.0, 3.0])

    # CPU placement always works.
    assert to_tensor(source, device="cpu").device.type == "cpu"

    # GPU placement is only exercised when CUDA is present.
    if torch.cuda.is_available():
        assert to_tensor(source, device="cuda").device.type == "cuda"
|
||||
|
||||
|
||||
def test_to_tensor_empty_dict():
    """Test to_tensor with an empty dictionary."""
    converted = to_tensor({})
    # An empty dict in yields an empty dict out.
    assert isinstance(converted, dict)
    assert len(converted) == 0
|
||||
|
||||
|
||||
def test_to_tensor_unsupported_type():
    """Test that unsupported input types raise TypeError."""
    # Strings and arbitrary objects have no registered handler.
    for bad_input in ("unsupported_string", object()):
        with pytest.raises(TypeError, match="Unsupported type for tensor conversion"):
            to_tensor(bad_input)
|
||||
|
||||
@@ -20,10 +20,10 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.processor.normalize_processor import (
|
||||
NormalizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
_convert_stats_to_tensors,
|
||||
hotswap_stats,
|
||||
)
|
||||
from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey
|
||||
@@ -51,7 +51,7 @@ def test_numpy_conversion():
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
}
|
||||
}
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
tensor_stats = to_tensor(stats)
|
||||
|
||||
assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor)
|
||||
assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor)
|
||||
@@ -66,7 +66,7 @@ def test_tensor_conversion():
|
||||
"std": torch.tensor([1.0, 1.0]),
|
||||
}
|
||||
}
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
tensor_stats = to_tensor(stats)
|
||||
|
||||
assert tensor_stats["action"]["mean"].dtype == torch.float32
|
||||
assert tensor_stats["action"]["std"].dtype == torch.float32
|
||||
@@ -79,7 +79,7 @@ def test_scalar_conversion():
|
||||
"std": 0.1,
|
||||
}
|
||||
}
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
tensor_stats = to_tensor(stats)
|
||||
|
||||
assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5))
|
||||
assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1))
|
||||
@@ -92,7 +92,7 @@ def test_list_conversion():
|
||||
"max": [1.0, 1.0, 2.0],
|
||||
}
|
||||
}
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
tensor_stats = to_tensor(stats)
|
||||
|
||||
assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0]))
|
||||
assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0]))
|
||||
@@ -105,7 +105,7 @@ def test_unsupported_type():
|
||||
}
|
||||
}
|
||||
with pytest.raises(TypeError, match="Unsupported type"):
|
||||
_convert_stats_to_tensors(stats)
|
||||
to_tensor(stats)
|
||||
|
||||
|
||||
# Helper functions to create feature maps and norm maps
|
||||
@@ -1017,7 +1017,7 @@ def test_hotswap_stats_basic_functionality():
|
||||
assert new_processor.steps[1].stats == new_stats
|
||||
|
||||
# Check that tensor stats are updated correctly
|
||||
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
|
||||
expected_tensor_stats = to_tensor(new_stats)
|
||||
for key in expected_tensor_stats:
|
||||
for stat_name in expected_tensor_stats[key]:
|
||||
torch.testing.assert_close(
|
||||
@@ -1223,7 +1223,7 @@ def test_hotswap_stats_multiple_normalizer_types():
|
||||
assert step.stats == new_stats
|
||||
|
||||
# Check tensor stats conversion
|
||||
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
|
||||
expected_tensor_stats = to_tensor(new_stats)
|
||||
for key in expected_tensor_stats:
|
||||
for stat_name in expected_tensor_stats[key]:
|
||||
torch.testing.assert_close(
|
||||
|
||||
Reference in New Issue
Block a user