mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-13 23:59:43 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6f1e49dbc4 | |||
| f286eb059c |
@@ -57,7 +57,6 @@ from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
@@ -72,7 +71,6 @@ from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor.core import TransitionKey
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
@@ -90,8 +88,6 @@ def rollout(
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
@@ -165,11 +161,9 @@ def rollout(
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
|
||||
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
|
||||
|
||||
# Convert to CPU / numpy.
|
||||
action: np.ndarray = action.to("cpu").numpy()
|
||||
action: np.ndarray = action.to("cpu").numpy()
|
||||
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||
|
||||
# Apply the next action.
|
||||
@@ -226,10 +220,8 @@ def rollout(
|
||||
def eval_policy(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
n_episodes: int,
|
||||
max_episodes_rendered: int = 0,
|
||||
videos_dir: Path | None = None,
|
||||
@@ -306,10 +298,6 @@ def eval_policy(
|
||||
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
|
||||
)
|
||||
rollout_data = rollout(
|
||||
env=env,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
env=env,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
@@ -496,22 +484,13 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
|
||||
|
||||
policy.eval()
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
|
||||
)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
|
||||
)
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
info = eval_policy(
|
||||
env=env,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
env=env,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
|
||||
@@ -33,19 +33,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
@@ -112,7 +100,8 @@ def test_act_processor_normalization():
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -122,7 +111,8 @@ def test_act_processor_normalization():
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
# Process action through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
action_transition = create_transition()
|
||||
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is unnormalized
|
||||
@@ -146,7 +136,8 @@ def test_act_processor_cuda():
|
||||
# Create CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -156,7 +147,8 @@ def test_act_processor_cuda():
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
action_transition = create_transition()
|
||||
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
@@ -181,7 +173,8 @@ def test_act_processor_accelerate_scenario():
|
||||
device = torch.device("cuda:0")
|
||||
observation = {OBS_STATE: torch.randn(1, 7).to(device)} # Already batched and on GPU
|
||||
action = torch.randn(1, 4).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -204,7 +197,8 @@ def test_act_processor_multi_gpu():
|
||||
device = torch.device("cuda:1")
|
||||
observation = {OBS_STATE: torch.randn(1, 7).to(device)}
|
||||
action = torch.randn(1, 4).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -227,7 +221,8 @@ def test_act_processor_without_stats():
|
||||
# Process should still work (but won't normalize without stats)
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
@@ -257,7 +252,8 @@ def test_act_processor_save_and_load():
|
||||
# Test that loaded processor works
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||
@@ -281,7 +277,8 @@ def test_act_processor_device_placement_preservation():
|
||||
# Process CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
|
||||
@@ -326,7 +323,8 @@ def test_act_processor_mixed_precision():
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}
|
||||
action = torch.randn(4, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -351,7 +349,8 @@ def test_act_processor_batch_consistency():
|
||||
# Test single sample (unbatched)
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 # Batched
|
||||
@@ -359,7 +358,8 @@ def test_act_processor_batch_consistency():
|
||||
# Test already batched data
|
||||
observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8
|
||||
action_batched = torch.randn(8, 4)
|
||||
transition_batched = create_transition(observation_batched, action_batched)
|
||||
transition_batched = create_transition(observation=observation_batched)
|
||||
transition_batched[TransitionKey.ACTION] = action_batched
|
||||
|
||||
processed_batched = preprocessor(transition_batched)
|
||||
assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8
|
||||
@@ -407,7 +407,8 @@ def test_act_processor_bfloat16_device_float32_normalizer():
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)} # Start with float32
|
||||
action = torch.randn(4, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.processor import DataProcessorPipeline, TransitionKey
|
||||
@@ -20,7 +22,7 @@ def _dummy_batch():
|
||||
|
||||
def test_observation_grouping_roundtrip():
|
||||
"""Test that observation.* keys are properly grouped and ungrouped."""
|
||||
proc = DataProcessorPipeline([])
|
||||
proc = DataProcessorPipeline[dict[str, Any]]([])
|
||||
batch_in = _dummy_batch()
|
||||
batch_out = proc(batch_in)
|
||||
|
||||
@@ -45,11 +47,12 @@ def test_observation_grouping_roundtrip():
|
||||
|
||||
def test_batch_to_transition_observation_grouping():
|
||||
"""Test that batch_to_transition correctly groups observation.* keys."""
|
||||
base_batch = _dummy_batch()
|
||||
batch = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": base_batch["observation.image.left"],
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
"action": "action_data",
|
||||
"action": torch.tensor([[0.1, 0.2]]),
|
||||
"next.reward": 1.5,
|
||||
"next.done": True,
|
||||
"next.truncated": False,
|
||||
@@ -74,7 +77,7 @@ def test_batch_to_transition_observation_grouping():
|
||||
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([[0.1, 0.2]]))
|
||||
assert transition[TransitionKey.REWARD] == 1.5
|
||||
assert transition[TransitionKey.DONE]
|
||||
assert not transition[TransitionKey.TRUNCATED]
|
||||
@@ -84,15 +87,16 @@ def test_batch_to_transition_observation_grouping():
|
||||
|
||||
def test_transition_to_batch_observation_flattening():
|
||||
"""Test that transition_to_batch correctly flattens observation dict."""
|
||||
base_batch = _dummy_batch()
|
||||
observation_dict = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": base_batch["observation.image.left"],
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
}
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: observation_dict,
|
||||
TransitionKey.ACTION: "action_data",
|
||||
TransitionKey.ACTION: torch.tensor([[0.3, 0.4]]),
|
||||
TransitionKey.REWARD: 1.5,
|
||||
TransitionKey.DONE: True,
|
||||
TransitionKey.TRUNCATED: False,
|
||||
@@ -113,7 +117,7 @@ def test_transition_to_batch_observation_flattening():
|
||||
assert batch["observation.state"] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields are mapped to next.* format
|
||||
assert batch["action"] == "action_data"
|
||||
assert torch.allclose(batch["action"], torch.tensor([[0.3, 0.4]]))
|
||||
assert batch["next.reward"] == 1.5
|
||||
assert batch["next.done"]
|
||||
assert not batch["next.truncated"]
|
||||
@@ -123,7 +127,7 @@ def test_transition_to_batch_observation_flattening():
|
||||
def test_no_observation_keys():
|
||||
"""Test behavior when there are no observation.* keys."""
|
||||
batch = {
|
||||
"action": "action_data",
|
||||
"action": torch.tensor([[0.7, 0.8]]),
|
||||
"next.reward": 2.0,
|
||||
"next.done": False,
|
||||
"next.truncated": True,
|
||||
@@ -136,7 +140,7 @@ def test_no_observation_keys():
|
||||
assert transition[TransitionKey.OBSERVATION] is None
|
||||
|
||||
# Check other fields
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([[0.7, 0.8]]))
|
||||
assert transition[TransitionKey.REWARD] == 2.0
|
||||
assert not transition[TransitionKey.DONE]
|
||||
assert transition[TransitionKey.TRUNCATED]
|
||||
@@ -144,7 +148,7 @@ def test_no_observation_keys():
|
||||
|
||||
# Round trip should work
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["action"] == "action_data"
|
||||
assert torch.allclose(reconstructed_batch["action"], torch.tensor([[0.7, 0.8]]))
|
||||
assert reconstructed_batch["next.reward"] == 2.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert reconstructed_batch["next.truncated"]
|
||||
@@ -153,13 +157,13 @@ def test_no_observation_keys():
|
||||
|
||||
def test_minimal_batch():
|
||||
"""Test with minimal batch containing only observation.* and action."""
|
||||
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
|
||||
batch = {"observation.state": "minimal_state", "action": torch.tensor([[0.9]])}
|
||||
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Check observation
|
||||
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||
assert transition[TransitionKey.ACTION] == "minimal_action"
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([[0.9]]))
|
||||
|
||||
# Check defaults
|
||||
assert transition[TransitionKey.REWARD] == 0.0
|
||||
@@ -171,7 +175,7 @@ def test_minimal_batch():
|
||||
# Round trip
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["observation.state"] == "minimal_state"
|
||||
assert reconstructed_batch["action"] == "minimal_action"
|
||||
assert torch.allclose(reconstructed_batch["action"], torch.tensor([[0.9]]))
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert not reconstructed_batch["next.truncated"]
|
||||
@@ -204,9 +208,10 @@ def test_empty_batch():
|
||||
|
||||
def test_complex_nested_observation():
|
||||
"""Test with complex nested observation data."""
|
||||
base_batch = _dummy_batch()
|
||||
batch = {
|
||||
"observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
|
||||
"observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
|
||||
"observation.image.left": {"image": base_batch["observation.image.left"], "timestamp": 1234567891},
|
||||
"observation.state": torch.randn(7),
|
||||
"action": torch.randn(8),
|
||||
"next.reward": 3.14,
|
||||
|
||||
@@ -28,21 +28,7 @@ from lerobot.processor import (
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
|
||||
):
|
||||
"""Helper to create an EnvTransition dictionary."""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info,
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||||
}
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def test_state_1d_to_2d():
|
||||
@@ -517,7 +503,7 @@ def test_action_non_tensor():
|
||||
assert np.array_equal(result[TransitionKey.ACTION], action_numpy)
|
||||
|
||||
# String action (edge case)
|
||||
action_string = "forward"
|
||||
action_string = "eef.pos.x"
|
||||
transition = create_transition(action=action_string)
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.ACTION] == action_string
|
||||
@@ -703,7 +689,7 @@ def test_complementary_data_none():
|
||||
transition = create_transition(complementary_data=None)
|
||||
result = processor(transition)
|
||||
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA] is None
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
|
||||
def test_complementary_data_empty():
|
||||
|
||||
@@ -31,19 +31,7 @@ from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
TransitionKey,
|
||||
)
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
@@ -115,7 +103,8 @@ def test_classifier_processor_normalization():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1) # Dummy action/reward
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -146,7 +135,8 @@ def test_classifier_processor_cuda():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -157,7 +147,8 @@ def test_classifier_processor_cuda():
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
reward_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
reward_transition = create_transition()
|
||||
reward_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
postprocessed = postprocessor(reward_transition)
|
||||
|
||||
# Check that output is back on CPU
|
||||
@@ -185,7 +176,8 @@ def test_classifier_processor_accelerate_scenario():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -212,7 +204,8 @@ def test_classifier_processor_multi_gpu():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -239,7 +232,8 @@ def test_classifier_processor_without_stats():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
@@ -273,7 +267,8 @@ def test_classifier_processor_save_and_load():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (10,)
|
||||
@@ -308,7 +303,8 @@ def test_classifier_processor_mixed_precision():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(1, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -338,7 +334,8 @@ def test_classifier_processor_batch_data():
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 224, 224),
|
||||
}
|
||||
action = torch.randn(batch_size, 1)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -363,7 +360,8 @@ def test_classifier_processor_postprocessor_identity():
|
||||
|
||||
# Create test data for postprocessor
|
||||
reward = torch.tensor([[0.8], [0.3], [0.9]]) # Batch of rewards/predictions
|
||||
transition = create_transition(action=reward)
|
||||
transition = create_transition()
|
||||
transition[TransitionKey.ACTION] = reward
|
||||
|
||||
# Process through postprocessor
|
||||
processed = postprocessor(transition)
|
||||
|
||||
@@ -20,28 +20,7 @@ import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
|
||||
):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
if reward is not None:
|
||||
transition[TransitionKey.REWARD] = reward
|
||||
if done is not None:
|
||||
transition[TransitionKey.DONE] = done
|
||||
if truncated is not None:
|
||||
transition[TransitionKey.TRUNCATED] = truncated
|
||||
if info is not None:
|
||||
transition[TransitionKey.INFO] = info
|
||||
if complementary_data is not None:
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
return transition
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def test_basic_functionality():
|
||||
@@ -147,14 +126,14 @@ def test_none_values():
|
||||
# Test with None observation
|
||||
transition = create_transition(observation=None, action=torch.randn(5))
|
||||
result = processor(transition)
|
||||
assert TransitionKey.OBSERVATION not in result
|
||||
assert result[TransitionKey.OBSERVATION] is None
|
||||
assert result[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
# Test with None action
|
||||
transition = create_transition(observation={"observation.state": torch.randn(10)}, action=None)
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
|
||||
assert TransitionKey.ACTION not in result
|
||||
assert result[TransitionKey.ACTION] is None
|
||||
|
||||
|
||||
def test_empty_observation():
|
||||
@@ -822,8 +801,8 @@ def test_complementary_data_none():
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Complementary data should not be in the result (same as input)
|
||||
assert TransitionKey.COMPLEMENTARY_DATA not in result
|
||||
# Complementary data should be an empty dict (standardized behavior)
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
|
||||
@@ -33,19 +33,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
@@ -118,7 +106,8 @@ def test_diffusion_processor_with_images():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -149,7 +138,8 @@ def test_diffusion_processor_cuda():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -160,7 +150,8 @@ def test_diffusion_processor_cuda():
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
action_transition = create_transition()
|
||||
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
@@ -188,7 +179,8 @@ def test_diffusion_processor_accelerate_scenario():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -215,7 +207,8 @@ def test_diffusion_processor_multi_gpu():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -242,7 +235,8 @@ def test_diffusion_processor_without_stats():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
@@ -276,7 +270,8 @@ def test_diffusion_processor_save_and_load():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||
@@ -322,7 +317,8 @@ def test_diffusion_processor_mixed_precision():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(6, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -352,7 +348,8 @@ def test_diffusion_processor_identity_normalization():
|
||||
OBS_IMAGE: image_value.clone(),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -381,7 +378,8 @@ def test_diffusion_processor_batch_consistency():
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 224, 224) if batch_size > 1 else torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(batch_size, 6) if batch_size > 1 else torch.randn(6)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = preprocessor(transition)
|
||||
|
||||
@@ -435,7 +433,8 @@ def test_diffusion_processor_bfloat16_device_float32_normalizer():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(6, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
@@ -34,6 +34,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
@@ -52,21 +53,6 @@ class MockTokenizerProcessorStep(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
elif key == "complementary_data":
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = value
|
||||
return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default PI0 configuration for testing."""
|
||||
config = PI0Config()
|
||||
@@ -219,7 +205,8 @@ def test_pi0_processor_cuda():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action, complementary_data={"task": "test task"})
|
||||
transition = create_transition(observation=observation, complementary_data={"task": "test task"})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -275,7 +262,8 @@ def test_pi0_processor_accelerate_scenario():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -331,7 +319,8 @@ def test_pi0_processor_multi_gpu():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -426,8 +415,9 @@ def test_pi0_processor_bfloat16_device_float32_normalizer():
|
||||
}
|
||||
action = torch.randn(6, dtype=torch.float32) # PI0 expects size 6
|
||||
transition = create_transition(
|
||||
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
|
||||
observation=observation, complementary_data={"task": "test bfloat16 adaptation"}
|
||||
)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
@@ -33,19 +33,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
@@ -117,7 +105,8 @@ def test_sac_processor_normalization_modes():
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization
|
||||
action = torch.rand(5) * 2 - 1 # Range [-1, 1]
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -129,7 +118,8 @@ def test_sac_processor_normalization_modes():
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 5)
|
||||
|
||||
# Process action through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
action_transition = create_transition()
|
||||
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is unnormalized (but still batched)
|
||||
@@ -153,7 +143,8 @@ def test_sac_processor_cuda():
|
||||
# Create CPU data
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -163,7 +154,8 @@ def test_sac_processor_cuda():
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
action_transition = create_transition()
|
||||
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
@@ -188,7 +180,8 @@ def test_sac_processor_accelerate_scenario():
|
||||
device = torch.device("cuda:0")
|
||||
observation = {OBS_STATE: torch.randn(10).to(device)}
|
||||
action = torch.randn(5).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -216,7 +209,8 @@ def test_sac_processor_multi_gpu():
|
||||
device = torch.device("cuda:1")
|
||||
observation = {OBS_STATE: torch.randn(10).to(device)}
|
||||
action = torch.randn(5).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -254,7 +248,8 @@ def test_sac_processor_without_stats():
|
||||
# Process should still work
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
@@ -284,7 +279,8 @@ def test_sac_processor_save_and_load():
|
||||
# Test that loaded processor works
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
|
||||
@@ -329,7 +325,8 @@ def test_sac_processor_mixed_precision():
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)}
|
||||
action = torch.randn(5, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -355,7 +352,8 @@ def test_sac_processor_batch_data():
|
||||
batch_size = 32
|
||||
observation = {OBS_STATE: torch.randn(batch_size, 10)}
|
||||
action = torch.randn(batch_size, 5)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -378,13 +376,14 @@ def test_sac_processor_edge_cases():
|
||||
)
|
||||
|
||||
# Test with empty observation
|
||||
transition = create_transition(observation={}, action=torch.randn(5))
|
||||
transition = create_transition(observation={})
|
||||
transition[TransitionKey.ACTION] = torch.randn(5)
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION] == {}
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 5)
|
||||
|
||||
# Test with None action
|
||||
transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=None)
|
||||
transition = create_transition(observation={OBS_STATE: torch.randn(10)})
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
|
||||
# When action is None, it may still be present with None value
|
||||
@@ -433,7 +432,8 @@ def test_sac_processor_bfloat16_device_float32_normalizer():
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} # Start with float32
|
||||
action = torch.randn(5, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
@@ -37,6 +37,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
@@ -55,21 +56,6 @@ class MockTokenizerProcessorStep(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
elif key == "complementary_data":
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = value
|
||||
return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default SmolVLA configuration for testing."""
|
||||
config = SmolVLAConfig()
|
||||
@@ -228,7 +214,8 @@ def test_smolvla_processor_cuda():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation, action, complementary_data={"task": "test task"})
|
||||
transition = create_transition(observation=observation, complementary_data={"task": "test task"})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -286,7 +273,8 @@ def test_smolvla_processor_accelerate_scenario():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 7).to(device)
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -344,7 +332,8 @@ def test_smolvla_processor_multi_gpu():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 7).to(device)
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -455,8 +444,9 @@ def test_smolvla_processor_bfloat16_device_float32_normalizer():
|
||||
}
|
||||
action = torch.randn(7, dtype=torch.float32)
|
||||
transition = create_transition(
|
||||
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
|
||||
observation=observation, complementary_data={"task": "test bfloat16 adaptation"}
|
||||
)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
@@ -33,19 +33,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
@@ -123,7 +111,8 @@ def test_vqbet_processor_with_images():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -154,7 +143,8 @@ def test_vqbet_processor_cuda():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -165,7 +155,8 @@ def test_vqbet_processor_cuda():
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
action_transition = create_transition()
|
||||
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
@@ -193,7 +184,8 @@ def test_vqbet_processor_accelerate_scenario():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 7).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -225,7 +217,8 @@ def test_vqbet_processor_multi_gpu():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 7).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -267,7 +260,8 @@ def test_vqbet_processor_without_stats():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
@@ -300,7 +294,8 @@ def test_vqbet_processor_save_and_load():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 8)
|
||||
@@ -349,7 +344,8 @@ def test_vqbet_processor_mixed_precision():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(7, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -379,7 +375,8 @@ def test_vqbet_processor_large_batch():
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 224, 224),
|
||||
}
|
||||
action = torch.randn(batch_size, 7)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -410,7 +407,8 @@ def test_vqbet_processor_sequential_processing():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = preprocessor(transition)
|
||||
results.append(processed)
|
||||
@@ -467,7 +465,8 @@ def test_vqbet_processor_bfloat16_device_float32_normalizer():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(7, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
Reference in New Issue
Block a user