mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-13 15:49:53 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 15960f0b5e | |||
| 8b43339563 | |||
| 5dababd21e | |||
| cbc46467b3 |
@@ -57,6 +57,7 @@ from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
@@ -71,6 +72,7 @@ from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor.core import TransitionKey
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
@@ -88,6 +90,8 @@ def rollout(
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
@@ -161,9 +165,11 @@ def rollout(
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
|
||||
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
|
||||
|
||||
# Convert to CPU / numpy.
|
||||
action: np.ndarray = action.to("cpu").numpy()
|
||||
action: np.ndarray = action.to("cpu").numpy()
|
||||
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||
|
||||
# Apply the next action.
|
||||
@@ -220,8 +226,10 @@ def rollout(
|
||||
def eval_policy(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
n_episodes: int,
|
||||
max_episodes_rendered: int = 0,
|
||||
videos_dir: Path | None = None,
|
||||
@@ -298,6 +306,10 @@ def eval_policy(
|
||||
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
|
||||
)
|
||||
rollout_data = rollout(
|
||||
env=env,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
env=env,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
@@ -484,13 +496,22 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
|
||||
|
||||
policy.eval()
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
|
||||
)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
|
||||
)
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
info = eval_policy(
|
||||
env=env,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
env=env,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
|
||||
@@ -33,7 +33,19 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
    """Build a sparse transition dictionary keyed by ``TransitionKey`` members.

    Args:
        observation: Stored under ``TransitionKey.OBSERVATION`` unless it is ``None``.
        action: Stored under ``TransitionKey.ACTION`` unless it is ``None``.
        **kwargs: Extra fields. Each keyword is upper-cased and looked up as an
            attribute of ``TransitionKey``; when the attribute exists the value is
            stored under it, otherwise the keyword is silently ignored.

    Returns:
        A new dict containing only the entries that were provided.
    """
    primary_fields = (
        (TransitionKey.OBSERVATION, observation),
        (TransitionKey.ACTION, action),
    )
    result = {slot: payload for slot, payload in primary_fields if payload is not None}

    for name, payload in kwargs.items():
        member_name = name.upper()
        # Keywords that do not name a TransitionKey attribute are dropped.
        if hasattr(TransitionKey, member_name):
            result[getattr(TransitionKey, member_name)] = payload

    return result
|
||||
|
||||
|
||||
def create_default_config():
|
||||
@@ -100,8 +112,7 @@ def test_act_processor_normalization():
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -111,8 +122,7 @@ def test_act_processor_normalization():
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
# Process action through postprocessor
|
||||
action_transition = create_transition()
|
||||
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is unnormalized
|
||||
@@ -136,8 +146,7 @@ def test_act_processor_cuda():
|
||||
# Create CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -147,8 +156,7 @@ def test_act_processor_cuda():
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition()
|
||||
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
@@ -173,8 +181,7 @@ def test_act_processor_accelerate_scenario():
|
||||
device = torch.device("cuda:0")
|
||||
observation = {OBS_STATE: torch.randn(1, 7).to(device)} # Already batched and on GPU
|
||||
action = torch.randn(1, 4).to(device)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -197,8 +204,7 @@ def test_act_processor_multi_gpu():
|
||||
device = torch.device("cuda:1")
|
||||
observation = {OBS_STATE: torch.randn(1, 7).to(device)}
|
||||
action = torch.randn(1, 4).to(device)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -221,8 +227,7 @@ def test_act_processor_without_stats():
|
||||
# Process should still work (but won't normalize without stats)
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
@@ -252,8 +257,7 @@ def test_act_processor_save_and_load():
|
||||
# Test that loaded processor works
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||
@@ -277,8 +281,7 @@ def test_act_processor_device_placement_preservation():
|
||||
# Process CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
|
||||
@@ -323,8 +326,7 @@ def test_act_processor_mixed_precision():
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}
|
||||
action = torch.randn(4, dtype=torch.float32)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -349,8 +351,7 @@ def test_act_processor_batch_consistency():
|
||||
# Test single sample (unbatched)
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 # Batched
|
||||
@@ -358,8 +359,7 @@ def test_act_processor_batch_consistency():
|
||||
# Test already batched data
|
||||
observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8
|
||||
action_batched = torch.randn(8, 4)
|
||||
transition_batched = create_transition(observation=observation_batched)
|
||||
transition_batched[TransitionKey.ACTION] = action_batched
|
||||
transition_batched = create_transition(observation_batched, action_batched)
|
||||
|
||||
processed_batched = preprocessor(transition_batched)
|
||||
assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8
|
||||
@@ -407,8 +407,7 @@ def test_act_processor_bfloat16_device_float32_normalizer():
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)} # Start with float32
|
||||
action = torch.randn(4, dtype=torch.float32)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.processor import DataProcessorPipeline, TransitionKey
|
||||
@@ -22,7 +20,7 @@ def _dummy_batch():
|
||||
|
||||
def test_observation_grouping_roundtrip():
|
||||
"""Test that observation.* keys are properly grouped and ungrouped."""
|
||||
proc = DataProcessorPipeline[dict[str, Any]]([])
|
||||
proc = DataProcessorPipeline([])
|
||||
batch_in = _dummy_batch()
|
||||
batch_out = proc(batch_in)
|
||||
|
||||
@@ -47,12 +45,11 @@ def test_observation_grouping_roundtrip():
|
||||
|
||||
def test_batch_to_transition_observation_grouping():
|
||||
"""Test that batch_to_transition correctly groups observation.* keys."""
|
||||
base_batch = _dummy_batch()
|
||||
batch = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": base_batch["observation.image.left"],
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
"action": torch.tensor([[0.1, 0.2]]),
|
||||
"action": "action_data",
|
||||
"next.reward": 1.5,
|
||||
"next.done": True,
|
||||
"next.truncated": False,
|
||||
@@ -77,7 +74,7 @@ def test_batch_to_transition_observation_grouping():
|
||||
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([[0.1, 0.2]]))
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert transition[TransitionKey.REWARD] == 1.5
|
||||
assert transition[TransitionKey.DONE]
|
||||
assert not transition[TransitionKey.TRUNCATED]
|
||||
@@ -87,16 +84,15 @@ def test_batch_to_transition_observation_grouping():
|
||||
|
||||
def test_transition_to_batch_observation_flattening():
|
||||
"""Test that transition_to_batch correctly flattens observation dict."""
|
||||
base_batch = _dummy_batch()
|
||||
observation_dict = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": base_batch["observation.image.left"],
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
}
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: observation_dict,
|
||||
TransitionKey.ACTION: torch.tensor([[0.3, 0.4]]),
|
||||
TransitionKey.ACTION: "action_data",
|
||||
TransitionKey.REWARD: 1.5,
|
||||
TransitionKey.DONE: True,
|
||||
TransitionKey.TRUNCATED: False,
|
||||
@@ -117,7 +113,7 @@ def test_transition_to_batch_observation_flattening():
|
||||
assert batch["observation.state"] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields are mapped to next.* format
|
||||
assert torch.allclose(batch["action"], torch.tensor([[0.3, 0.4]]))
|
||||
assert batch["action"] == "action_data"
|
||||
assert batch["next.reward"] == 1.5
|
||||
assert batch["next.done"]
|
||||
assert not batch["next.truncated"]
|
||||
@@ -127,7 +123,7 @@ def test_transition_to_batch_observation_flattening():
|
||||
def test_no_observation_keys():
|
||||
"""Test behavior when there are no observation.* keys."""
|
||||
batch = {
|
||||
"action": torch.tensor([[0.7, 0.8]]),
|
||||
"action": "action_data",
|
||||
"next.reward": 2.0,
|
||||
"next.done": False,
|
||||
"next.truncated": True,
|
||||
@@ -140,7 +136,7 @@ def test_no_observation_keys():
|
||||
assert transition[TransitionKey.OBSERVATION] is None
|
||||
|
||||
# Check other fields
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([[0.7, 0.8]]))
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert transition[TransitionKey.REWARD] == 2.0
|
||||
assert not transition[TransitionKey.DONE]
|
||||
assert transition[TransitionKey.TRUNCATED]
|
||||
@@ -148,7 +144,7 @@ def test_no_observation_keys():
|
||||
|
||||
# Round trip should work
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert torch.allclose(reconstructed_batch["action"], torch.tensor([[0.7, 0.8]]))
|
||||
assert reconstructed_batch["action"] == "action_data"
|
||||
assert reconstructed_batch["next.reward"] == 2.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert reconstructed_batch["next.truncated"]
|
||||
@@ -157,13 +153,13 @@ def test_no_observation_keys():
|
||||
|
||||
def test_minimal_batch():
|
||||
"""Test with minimal batch containing only observation.* and action."""
|
||||
batch = {"observation.state": "minimal_state", "action": torch.tensor([[0.9]])}
|
||||
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
|
||||
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Check observation
|
||||
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([[0.9]]))
|
||||
assert transition[TransitionKey.ACTION] == "minimal_action"
|
||||
|
||||
# Check defaults
|
||||
assert transition[TransitionKey.REWARD] == 0.0
|
||||
@@ -175,7 +171,7 @@ def test_minimal_batch():
|
||||
# Round trip
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["observation.state"] == "minimal_state"
|
||||
assert torch.allclose(reconstructed_batch["action"], torch.tensor([[0.9]]))
|
||||
assert reconstructed_batch["action"] == "minimal_action"
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert not reconstructed_batch["next.truncated"]
|
||||
@@ -208,10 +204,9 @@ def test_empty_batch():
|
||||
|
||||
def test_complex_nested_observation():
|
||||
"""Test with complex nested observation data."""
|
||||
base_batch = _dummy_batch()
|
||||
batch = {
|
||||
"observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
|
||||
"observation.image.left": {"image": base_batch["observation.image.left"], "timestamp": 1234567891},
|
||||
"observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
|
||||
"observation.state": torch.randn(7),
|
||||
"action": torch.randn(8),
|
||||
"next.reward": 3.14,
|
||||
|
||||
@@ -28,7 +28,21 @@ from lerobot.processor import (
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def create_transition(
    observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
    """Build a transition dictionary that always carries all seven keys.

    Every argument is stored under its matching ``TransitionKey`` member exactly
    as given, so fields that were not supplied appear in the result with the
    value ``None`` rather than being omitted.

    Returns:
        A new dict with the OBSERVATION, ACTION, REWARD, DONE, TRUNCATED, INFO
        and COMPLEMENTARY_DATA keys, in that order.
    """
    keyed_values = (
        (TransitionKey.OBSERVATION, observation),
        (TransitionKey.ACTION, action),
        (TransitionKey.REWARD, reward),
        (TransitionKey.DONE, done),
        (TransitionKey.TRUNCATED, truncated),
        (TransitionKey.INFO, info),
        (TransitionKey.COMPLEMENTARY_DATA, complementary_data),
    )
    return dict(keyed_values)
|
||||
|
||||
|
||||
def test_state_1d_to_2d():
|
||||
@@ -503,7 +517,7 @@ def test_action_non_tensor():
|
||||
assert np.array_equal(result[TransitionKey.ACTION], action_numpy)
|
||||
|
||||
# String action (edge case)
|
||||
action_string = "eef.pos.x"
|
||||
action_string = "forward"
|
||||
transition = create_transition(action=action_string)
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.ACTION] == action_string
|
||||
@@ -689,7 +703,7 @@ def test_complementary_data_none():
|
||||
transition = create_transition(complementary_data=None)
|
||||
result = processor(transition)
|
||||
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA] is None
|
||||
|
||||
|
||||
def test_complementary_data_empty():
|
||||
|
||||
@@ -31,7 +31,19 @@ from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
    """Assemble a transition dict holding only the fields that were supplied.

    Args:
        observation: Placed under ``TransitionKey.OBSERVATION`` when not ``None``.
        action: Placed under ``TransitionKey.ACTION`` when not ``None``.
        **kwargs: Additional fields. A keyword is kept only if its upper-cased
            name is an attribute of ``TransitionKey``; the value is then stored
            under that attribute. Unrecognised keywords are ignored.

    Returns:
        The newly built dict (possibly empty).
    """
    built = {}

    for slot, payload in (
        (TransitionKey.OBSERVATION, observation),
        (TransitionKey.ACTION, action),
    ):
        if payload is not None:
            built[slot] = payload

    # Map each recognised keyword onto the TransitionKey attribute of the same
    # (upper-cased) name; anything else is skipped without error.
    recognised = {
        getattr(TransitionKey, name.upper()): payload
        for name, payload in kwargs.items()
        if hasattr(TransitionKey, name.upper())
    }
    built.update(recognised)

    return built
|
||||
|
||||
|
||||
def create_default_config():
|
||||
@@ -103,8 +115,7 @@ def test_classifier_processor_normalization():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1) # Dummy action/reward
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -135,8 +146,7 @@ def test_classifier_processor_cuda():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -147,8 +157,7 @@ def test_classifier_processor_cuda():
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
reward_transition = create_transition()
|
||||
reward_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
reward_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(reward_transition)
|
||||
|
||||
# Check that output is back on CPU
|
||||
@@ -176,8 +185,7 @@ def test_classifier_processor_accelerate_scenario():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1).to(device)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -204,8 +212,7 @@ def test_classifier_processor_multi_gpu():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1).to(device)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -232,8 +239,7 @@ def test_classifier_processor_without_stats():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
@@ -267,8 +273,7 @@ def test_classifier_processor_save_and_load():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (10,)
|
||||
@@ -303,8 +308,7 @@ def test_classifier_processor_mixed_precision():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(1, dtype=torch.float32)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -334,8 +338,7 @@ def test_classifier_processor_batch_data():
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 224, 224),
|
||||
}
|
||||
action = torch.randn(batch_size, 1)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -360,8 +363,7 @@ def test_classifier_processor_postprocessor_identity():
|
||||
|
||||
# Create test data for postprocessor
|
||||
reward = torch.tensor([[0.8], [0.3], [0.9]]) # Batch of rewards/predictions
|
||||
transition = create_transition()
|
||||
transition[TransitionKey.ACTION] = reward
|
||||
transition = create_transition(action=reward)
|
||||
|
||||
# Process through postprocessor
|
||||
processed = postprocessor(transition)
|
||||
|
||||
@@ -20,7 +20,28 @@ import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def create_transition(
    observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
    """Build a transition dictionary that omits every field left as ``None``.

    Each argument maps to the ``TransitionKey`` member of the same name. Only
    arguments whose value is not ``None`` produce an entry, so a field that was
    not supplied is absent from the result rather than present with ``None``.

    Returns:
        A new dict with the provided fields, in the order OBSERVATION, ACTION,
        REWARD, DONE, TRUNCATED, INFO, COMPLEMENTARY_DATA.
    """
    candidates = (
        (TransitionKey.OBSERVATION, observation),
        (TransitionKey.ACTION, action),
        (TransitionKey.REWARD, reward),
        (TransitionKey.DONE, done),
        (TransitionKey.TRUNCATED, truncated),
        (TransitionKey.INFO, info),
        (TransitionKey.COMPLEMENTARY_DATA, complementary_data),
    )
    # Filter on identity with None so falsy values (0.0, False, {}) are kept.
    return {slot: payload for slot, payload in candidates if payload is not None}
|
||||
|
||||
|
||||
def test_basic_functionality():
|
||||
@@ -126,14 +147,14 @@ def test_none_values():
|
||||
# Test with None observation
|
||||
transition = create_transition(observation=None, action=torch.randn(5))
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.OBSERVATION] is None
|
||||
assert TransitionKey.OBSERVATION not in result
|
||||
assert result[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
# Test with None action
|
||||
transition = create_transition(observation={"observation.state": torch.randn(10)}, action=None)
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
|
||||
assert result[TransitionKey.ACTION] is None
|
||||
assert TransitionKey.ACTION not in result
|
||||
|
||||
|
||||
def test_empty_observation():
|
||||
@@ -801,8 +822,8 @@ def test_complementary_data_none():
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Complementary data should be an empty dict (standardized behavior)
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
# Complementary data should not be in the result (same as input)
|
||||
assert TransitionKey.COMPLEMENTARY_DATA not in result
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
|
||||
@@ -33,7 +33,19 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
    """Create a transition dict containing only the entries that were passed in.

    Args:
        observation: Added under ``TransitionKey.OBSERVATION`` if not ``None``.
        action: Added under ``TransitionKey.ACTION`` if not ``None``.
        **kwargs: Optional extra fields. The keyword name is upper-cased and
            resolved as an attribute of ``TransitionKey``; values whose keyword
            does not resolve are discarded silently.

    Returns:
        A fresh dict keyed by ``TransitionKey`` attributes.
    """
    entries = {}

    if observation is not None:
        entries[TransitionKey.OBSERVATION] = observation
    if action is not None:
        entries[TransitionKey.ACTION] = action

    for field_name, field_value in kwargs.items():
        attr_name = field_name.upper()
        if not hasattr(TransitionKey, attr_name):
            # No TransitionKey attribute with this name: ignore the keyword.
            continue
        entries[getattr(TransitionKey, attr_name)] = field_value

    return entries
|
||||
|
||||
|
||||
def create_default_config():
|
||||
@@ -106,8 +118,7 @@ def test_diffusion_processor_with_images():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -138,8 +149,7 @@ def test_diffusion_processor_cuda():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -150,8 +160,7 @@ def test_diffusion_processor_cuda():
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition()
|
||||
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
@@ -179,8 +188,7 @@ def test_diffusion_processor_accelerate_scenario():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -207,8 +215,7 @@ def test_diffusion_processor_multi_gpu():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -235,8 +242,7 @@ def test_diffusion_processor_without_stats():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
@@ -270,8 +276,7 @@ def test_diffusion_processor_save_and_load():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||
@@ -317,8 +322,7 @@ def test_diffusion_processor_mixed_precision():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(6, dtype=torch.float32)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -348,8 +352,7 @@ def test_diffusion_processor_identity_normalization():
|
||||
OBS_IMAGE: image_value.clone(),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -378,8 +381,7 @@ def test_diffusion_processor_batch_consistency():
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 224, 224) if batch_size > 1 else torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(batch_size, 6) if batch_size > 1 else torch.randn(6)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
|
||||
@@ -433,8 +435,7 @@ def test_diffusion_processor_bfloat16_device_float32_normalizer():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(6, dtype=torch.float32)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
@@ -34,7 +34,6 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
@@ -53,6 +52,21 @@ class MockTokenizerProcessorStep(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
    """Build a transition dictionary keyed by ``TransitionKey`` attributes.

    Args:
        observation: Stored under ``TransitionKey.OBSERVATION`` unless it is ``None``.
        action: Stored under ``TransitionKey.ACTION`` unless it is ``None``.
        **kwargs: Extra entries. Each keyword is upper-cased and looked up on
            ``TransitionKey`` (e.g. ``complementary_data`` resolves to
            ``TransitionKey.COMPLEMENTARY_DATA``). Keywords with no matching
            attribute are ignored.

    Returns:
        A new dict holding only the entries that were supplied.
    """
    transition = {}
    if observation is not None:
        transition[TransitionKey.OBSERVATION] = observation
    if action is not None:
        transition[TransitionKey.ACTION] = action
    for key, value in kwargs.items():
        # The upper-cased lookup already covers "complementary_data", so the
        # former dedicated branch for it could never run and has been dropped.
        attr_name = key.upper()
        if hasattr(TransitionKey, attr_name):
            transition[getattr(TransitionKey, attr_name)] = value
    return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default PI0 configuration for testing."""
|
||||
config = PI0Config()
|
||||
@@ -205,8 +219,7 @@ def test_pi0_processor_cuda():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation=observation, complementary_data={"task": "test task"})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action, complementary_data={"task": "test task"})
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -262,8 +275,7 @@ def test_pi0_processor_accelerate_scenario():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -319,8 +331,7 @@ def test_pi0_processor_multi_gpu():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -415,9 +426,8 @@ def test_pi0_processor_bfloat16_device_float32_normalizer():
|
||||
}
|
||||
action = torch.randn(6, dtype=torch.float32) # PI0 expects size 6
|
||||
transition = create_transition(
|
||||
observation=observation, complementary_data={"task": "test bfloat16 adaptation"}
|
||||
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
|
||||
)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
@@ -33,7 +33,19 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
    """Assemble a transition dict from the pieces a test supplies.

    ``observation`` and ``action`` are stored under ``TransitionKey.OBSERVATION``
    and ``TransitionKey.ACTION`` when they are not ``None``. Every extra keyword
    is upper-cased and, if ``TransitionKey`` has an attribute with that name,
    stored under it; other keywords are skipped.
    """
    entries = {}

    # Only record the two main fields when a value was actually given.
    if observation is not None:
        entries[TransitionKey.OBSERVATION] = observation
    if action is not None:
        entries[TransitionKey.ACTION] = action

    for name, payload in kwargs.items():
        member_name = name.upper()
        if not hasattr(TransitionKey, member_name):
            # No attribute of that name on TransitionKey: leave it out.
            continue
        entries[getattr(TransitionKey, member_name)] = payload

    return entries
|
||||
|
||||
|
||||
def create_default_config():
|
||||
@@ -105,8 +117,7 @@ def test_sac_processor_normalization_modes():
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization
|
||||
action = torch.rand(5) * 2 - 1 # Range [-1, 1]
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -118,8 +129,7 @@ def test_sac_processor_normalization_modes():
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 5)
|
||||
|
||||
# Process action through postprocessor
|
||||
action_transition = create_transition()
|
||||
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is unnormalized (but still batched)
|
||||
@@ -143,8 +153,7 @@ def test_sac_processor_cuda():
|
||||
# Create CPU data
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -154,8 +163,7 @@ def test_sac_processor_cuda():
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition()
|
||||
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
@@ -180,8 +188,7 @@ def test_sac_processor_accelerate_scenario():
|
||||
device = torch.device("cuda:0")
|
||||
observation = {OBS_STATE: torch.randn(10).to(device)}
|
||||
action = torch.randn(5).to(device)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -209,8 +216,7 @@ def test_sac_processor_multi_gpu():
|
||||
device = torch.device("cuda:1")
|
||||
observation = {OBS_STATE: torch.randn(10).to(device)}
|
||||
action = torch.randn(5).to(device)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -248,8 +254,7 @@ def test_sac_processor_without_stats():
|
||||
# Process should still work
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
@@ -279,8 +284,7 @@ def test_sac_processor_save_and_load():
|
||||
# Test that loaded processor works
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
|
||||
@@ -325,8 +329,7 @@ def test_sac_processor_mixed_precision():
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)}
|
||||
action = torch.randn(5, dtype=torch.float32)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -352,8 +355,7 @@ def test_sac_processor_batch_data():
|
||||
batch_size = 32
|
||||
observation = {OBS_STATE: torch.randn(batch_size, 10)}
|
||||
action = torch.randn(batch_size, 5)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -376,14 +378,13 @@ def test_sac_processor_edge_cases():
|
||||
)
|
||||
|
||||
# Test with empty observation
|
||||
transition = create_transition(observation={})
|
||||
transition[TransitionKey.ACTION] = torch.randn(5)
|
||||
transition = create_transition(observation={}, action=torch.randn(5))
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION] == {}
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 5)
|
||||
|
||||
# Test with None action
|
||||
transition = create_transition(observation={OBS_STATE: torch.randn(10)})
|
||||
transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=None)
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
|
||||
# When action is None, it may still be present with None value
|
||||
@@ -432,8 +433,7 @@ def test_sac_processor_bfloat16_device_float32_normalizer():
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} # Start with float32
|
||||
action = torch.randn(5, dtype=torch.float32)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
@@ -37,7 +37,6 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
@@ -56,6 +55,21 @@ class MockTokenizerProcessorStep(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
    """Create a transition dictionary keyed by ``TransitionKey`` attributes.

    Args:
        observation: Value for ``TransitionKey.OBSERVATION``; omitted when ``None``.
        action: Value for ``TransitionKey.ACTION``; omitted when ``None``.
        **kwargs: Additional entries. A keyword is upper-cased and matched
            against the attributes of ``TransitionKey`` (so
            ``complementary_data`` maps to ``TransitionKey.COMPLEMENTARY_DATA``).
            Keywords that match no attribute are ignored.

    Returns:
        A freshly created dict containing the supplied entries only.
    """
    transition = {}
    if observation is not None:
        transition[TransitionKey.OBSERVATION] = observation
    if action is not None:
        transition[TransitionKey.ACTION] = action
    for key, value in kwargs.items():
        # "complementary_data" is resolved by the generic upper-cased lookup;
        # the separate branch that used to follow it was unreachable.
        attr_name = key.upper()
        if hasattr(TransitionKey, attr_name):
            transition[getattr(TransitionKey, attr_name)] = value
    return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default SmolVLA configuration for testing."""
|
||||
config = SmolVLAConfig()
|
||||
@@ -214,8 +228,7 @@ def test_smolvla_processor_cuda():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation=observation, complementary_data={"task": "test task"})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action, complementary_data={"task": "test task"})
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -273,8 +286,7 @@ def test_smolvla_processor_accelerate_scenario():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 7).to(device)
|
||||
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -332,8 +344,7 @@ def test_smolvla_processor_multi_gpu():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 7).to(device)
|
||||
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -444,9 +455,8 @@ def test_smolvla_processor_bfloat16_device_float32_normalizer():
|
||||
}
|
||||
action = torch.randn(7, dtype=torch.float32)
|
||||
transition = create_transition(
|
||||
observation=observation, complementary_data={"task": "test bfloat16 adaptation"}
|
||||
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
|
||||
)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
@@ -33,7 +33,19 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
    """Return a transition dict built from the given test inputs.

    Non-``None`` ``observation`` / ``action`` values go under
    ``TransitionKey.OBSERVATION`` / ``TransitionKey.ACTION``. Each remaining
    keyword is upper-cased; when ``TransitionKey`` exposes an attribute of that
    name the value is stored under it, otherwise the keyword is dropped.
    """
    result = {}

    if observation is not None:
        result[TransitionKey.OBSERVATION] = observation

    if action is not None:
        result[TransitionKey.ACTION] = action

    # Route the extra keywords through their upper-cased TransitionKey names.
    for kw_name, kw_value in kwargs.items():
        upper_name = kw_name.upper()
        if hasattr(TransitionKey, upper_name):
            target = getattr(TransitionKey, upper_name)
            result[target] = kw_value

    return result
|
||||
|
||||
|
||||
def create_default_config():
|
||||
@@ -111,8 +123,7 @@ def test_vqbet_processor_with_images():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -143,8 +154,7 @@ def test_vqbet_processor_cuda():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -155,8 +165,7 @@ def test_vqbet_processor_cuda():
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition()
|
||||
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
@@ -184,8 +193,7 @@ def test_vqbet_processor_accelerate_scenario():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 7).to(device)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -217,8 +225,7 @@ def test_vqbet_processor_multi_gpu():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 7).to(device)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -260,8 +267,7 @@ def test_vqbet_processor_without_stats():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
@@ -294,8 +300,7 @@ def test_vqbet_processor_save_and_load():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 8)
|
||||
@@ -344,8 +349,7 @@ def test_vqbet_processor_mixed_precision():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(7, dtype=torch.float32)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -375,8 +379,7 @@ def test_vqbet_processor_large_batch():
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 224, 224),
|
||||
}
|
||||
action = torch.randn(batch_size, 7)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -407,8 +410,7 @@ def test_vqbet_processor_sequential_processing():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
results.append(processed)
|
||||
@@ -465,8 +467,7 @@ def test_vqbet_processor_bfloat16_device_float32_normalizer():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(7, dtype=torch.float32)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
Reference in New Issue
Block a user