Compare commits

..

4 Commits

Author SHA1 Message Date
AdilZouitine 15960f0b5e refactor(utils): enhance task handling in add_envs_task function
- Improved the `add_envs_task` function to validate the output of `task_description` and `task` calls, ensuring they return lists of strings.
- Removed the use of `else` statement for environments without language instructions, simplifying the logic and enhancing readability.
- Streamlined the observation dictionary handling by ensuring consistent data types for task attributes.
2025-09-10 10:05:43 +02:00
AdilZouitine 8b43339563 debug 2025-09-10 10:05:43 +02:00
AdilZouitine 5dababd21e refactor(eval): remove redundant observation device conversion in rollout function
- Eliminated unnecessary device conversion for the observation dictionary within the `rollout` function, streamlining the code and enhancing readability.
- This change simplifies the observation handling process, aligning with the preference for clearer solutions.
2025-09-10 10:05:43 +02:00
AdilZouitine cbc46467b3 refactor(eval): integrate preprocessor and postprocessor into rollout and eval_policy functions
- Updated the `rollout` and `eval_policy` functions to accept preprocessor and postprocessor parameters, enhancing the flexibility of the evaluation pipeline.
- Adjusted the implementation to apply preprocessing and postprocessing steps during policy evaluation, improving the overall data handling and processing flow.
2025-09-10 10:05:43 +02:00
11 changed files with 241 additions and 167 deletions
+23 -2
View File
@@ -57,6 +57,7 @@ from dataclasses import asdict
from pathlib import Path
from pprint import pformat
from typing import Any
from typing import Any
import einops
import gymnasium as gym
@@ -71,6 +72,7 @@ from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor.core import TransitionKey
from lerobot.processor.pipeline import PolicyProcessorPipeline
@@ -88,6 +90,8 @@ def rollout(
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
@@ -161,9 +165,11 @@ def rollout(
with torch.inference_mode():
action = policy.select_action(observation)
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
# Convert to CPU / numpy.
action: np.ndarray = action.to("cpu").numpy()
action: np.ndarray = action.to("cpu").numpy()
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
# Apply the next action.
@@ -220,8 +226,10 @@ def rollout(
def eval_policy(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
n_episodes: int,
max_episodes_rendered: int = 0,
videos_dir: Path | None = None,
@@ -298,6 +306,10 @@ def eval_policy(
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
)
rollout_data = rollout(
env=env,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
env=env,
policy=policy,
preprocessor=preprocessor,
@@ -484,13 +496,22 @@ def eval_main(cfg: EvalPipelineConfig):
env_cfg=cfg.env,
)
policy.eval()
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy(
env=env,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
env=env,
policy=policy,
preprocessor=preprocessor,
+26 -27
View File
@@ -33,7 +33,19 @@ from lerobot.processor import (
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition
def create_transition(observation=None, action=None, **kwargs):
"""Helper function to create a transition dictionary."""
transition = {}
if observation is not None:
transition[TransitionKey.OBSERVATION] = observation
if action is not None:
transition[TransitionKey.ACTION] = action
for key, value in kwargs.items():
if hasattr(TransitionKey, key.upper()):
transition[getattr(TransitionKey, key.upper())] = value
return transition
def create_default_config():
@@ -100,8 +112,7 @@ def test_act_processor_normalization():
# Create test data
observation = {OBS_STATE: torch.randn(7)}
action = torch.randn(4)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -111,8 +122,7 @@ def test_act_processor_normalization():
assert processed[TransitionKey.ACTION].shape == (1, 4)
# Process action through postprocessor
action_transition = create_transition()
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
action_transition = create_transition(action=processed[TransitionKey.ACTION])
postprocessed = postprocessor(action_transition)
# Check that action is unnormalized
@@ -136,8 +146,7 @@ def test_act_processor_cuda():
# Create CPU data
observation = {OBS_STATE: torch.randn(7)}
action = torch.randn(4)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -147,8 +156,7 @@ def test_act_processor_cuda():
assert processed[TransitionKey.ACTION].device.type == "cuda"
# Process through postprocessor
action_transition = create_transition()
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
action_transition = create_transition(action=processed[TransitionKey.ACTION])
postprocessed = postprocessor(action_transition)
# Check that action is back on CPU
@@ -173,8 +181,7 @@ def test_act_processor_accelerate_scenario():
device = torch.device("cuda:0")
observation = {OBS_STATE: torch.randn(1, 7).to(device)} # Already batched and on GPU
action = torch.randn(1, 4).to(device)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -197,8 +204,7 @@ def test_act_processor_multi_gpu():
device = torch.device("cuda:1")
observation = {OBS_STATE: torch.randn(1, 7).to(device)}
action = torch.randn(1, 4).to(device)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -221,8 +227,7 @@ def test_act_processor_without_stats():
# Process should still work (but won't normalize without stats)
observation = {OBS_STATE: torch.randn(7)}
action = torch.randn(4)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
processed = preprocessor(transition)
assert processed is not None
@@ -252,8 +257,7 @@ def test_act_processor_save_and_load():
# Test that loaded processor works
observation = {OBS_STATE: torch.randn(7)}
action = torch.randn(4)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
processed = loaded_preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
@@ -277,8 +281,7 @@ def test_act_processor_device_placement_preservation():
# Process CPU data
observation = {OBS_STATE: torch.randn(7)}
action = torch.randn(4)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
processed = preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
@@ -323,8 +326,7 @@ def test_act_processor_mixed_precision():
# Create test data
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}
action = torch.randn(4, dtype=torch.float32)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -349,8 +351,7 @@ def test_act_processor_batch_consistency():
# Test single sample (unbatched)
observation = {OBS_STATE: torch.randn(7)}
action = torch.randn(4)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
processed = preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 # Batched
@@ -358,8 +359,7 @@ def test_act_processor_batch_consistency():
# Test already batched data
observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8
action_batched = torch.randn(8, 4)
transition_batched = create_transition(observation=observation_batched)
transition_batched[TransitionKey.ACTION] = action_batched
transition_batched = create_transition(observation_batched, action_batched)
processed_batched = preprocessor(transition_batched)
assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8
@@ -407,8 +407,7 @@ def test_act_processor_bfloat16_device_float32_normalizer():
# Create test data
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)} # Start with float32
action = torch.randn(4, dtype=torch.float32)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through full pipeline
processed = preprocessor(transition)
+14 -19
View File
@@ -1,5 +1,3 @@
from typing import Any
import torch
from lerobot.processor import DataProcessorPipeline, TransitionKey
@@ -22,7 +20,7 @@ def _dummy_batch():
def test_observation_grouping_roundtrip():
"""Test that observation.* keys are properly grouped and ungrouped."""
proc = DataProcessorPipeline[dict[str, Any]]([])
proc = DataProcessorPipeline([])
batch_in = _dummy_batch()
batch_out = proc(batch_in)
@@ -47,12 +45,11 @@ def test_observation_grouping_roundtrip():
def test_batch_to_transition_observation_grouping():
"""Test that batch_to_transition correctly groups observation.* keys."""
base_batch = _dummy_batch()
batch = {
"observation.image.top": torch.randn(1, 3, 128, 128),
"observation.image.left": base_batch["observation.image.left"],
"observation.image.left": torch.randn(1, 3, 128, 128),
"observation.state": [1, 2, 3, 4],
"action": torch.tensor([[0.1, 0.2]]),
"action": "action_data",
"next.reward": 1.5,
"next.done": True,
"next.truncated": False,
@@ -77,7 +74,7 @@ def test_batch_to_transition_observation_grouping():
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
# Check other fields
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([[0.1, 0.2]]))
assert transition[TransitionKey.ACTION] == "action_data"
assert transition[TransitionKey.REWARD] == 1.5
assert transition[TransitionKey.DONE]
assert not transition[TransitionKey.TRUNCATED]
@@ -87,16 +84,15 @@ def test_batch_to_transition_observation_grouping():
def test_transition_to_batch_observation_flattening():
"""Test that transition_to_batch correctly flattens observation dict."""
base_batch = _dummy_batch()
observation_dict = {
"observation.image.top": torch.randn(1, 3, 128, 128),
"observation.image.left": base_batch["observation.image.left"],
"observation.image.left": torch.randn(1, 3, 128, 128),
"observation.state": [1, 2, 3, 4],
}
transition = {
TransitionKey.OBSERVATION: observation_dict,
TransitionKey.ACTION: torch.tensor([[0.3, 0.4]]),
TransitionKey.ACTION: "action_data",
TransitionKey.REWARD: 1.5,
TransitionKey.DONE: True,
TransitionKey.TRUNCATED: False,
@@ -117,7 +113,7 @@ def test_transition_to_batch_observation_flattening():
assert batch["observation.state"] == [1, 2, 3, 4]
# Check other fields are mapped to next.* format
assert torch.allclose(batch["action"], torch.tensor([[0.3, 0.4]]))
assert batch["action"] == "action_data"
assert batch["next.reward"] == 1.5
assert batch["next.done"]
assert not batch["next.truncated"]
@@ -127,7 +123,7 @@ def test_transition_to_batch_observation_flattening():
def test_no_observation_keys():
"""Test behavior when there are no observation.* keys."""
batch = {
"action": torch.tensor([[0.7, 0.8]]),
"action": "action_data",
"next.reward": 2.0,
"next.done": False,
"next.truncated": True,
@@ -140,7 +136,7 @@ def test_no_observation_keys():
assert transition[TransitionKey.OBSERVATION] is None
# Check other fields
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([[0.7, 0.8]]))
assert transition[TransitionKey.ACTION] == "action_data"
assert transition[TransitionKey.REWARD] == 2.0
assert not transition[TransitionKey.DONE]
assert transition[TransitionKey.TRUNCATED]
@@ -148,7 +144,7 @@ def test_no_observation_keys():
# Round trip should work
reconstructed_batch = transition_to_batch(transition)
assert torch.allclose(reconstructed_batch["action"], torch.tensor([[0.7, 0.8]]))
assert reconstructed_batch["action"] == "action_data"
assert reconstructed_batch["next.reward"] == 2.0
assert not reconstructed_batch["next.done"]
assert reconstructed_batch["next.truncated"]
@@ -157,13 +153,13 @@ def test_no_observation_keys():
def test_minimal_batch():
"""Test with minimal batch containing only observation.* and action."""
batch = {"observation.state": "minimal_state", "action": torch.tensor([[0.9]])}
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
transition = batch_to_transition(batch)
# Check observation
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([[0.9]]))
assert transition[TransitionKey.ACTION] == "minimal_action"
# Check defaults
assert transition[TransitionKey.REWARD] == 0.0
@@ -175,7 +171,7 @@ def test_minimal_batch():
# Round trip
reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch["observation.state"] == "minimal_state"
assert torch.allclose(reconstructed_batch["action"], torch.tensor([[0.9]]))
assert reconstructed_batch["action"] == "minimal_action"
assert reconstructed_batch["next.reward"] == 0.0
assert not reconstructed_batch["next.done"]
assert not reconstructed_batch["next.truncated"]
@@ -208,10 +204,9 @@ def test_empty_batch():
def test_complex_nested_observation():
"""Test with complex nested observation data."""
base_batch = _dummy_batch()
batch = {
"observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
"observation.image.left": {"image": base_batch["observation.image.left"], "timestamp": 1234567891},
"observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
"observation.state": torch.randn(7),
"action": torch.randn(8),
"next.reward": 3.14,
+17 -3
View File
@@ -28,7 +28,21 @@ from lerobot.processor import (
ProcessorStepRegistry,
TransitionKey,
)
from lerobot.processor.converters import create_transition
def create_transition(
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
"""Helper to create an EnvTransition dictionary."""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
}
def test_state_1d_to_2d():
@@ -503,7 +517,7 @@ def test_action_non_tensor():
assert np.array_equal(result[TransitionKey.ACTION], action_numpy)
# String action (edge case)
action_string = "eef.pos.x"
action_string = "forward"
transition = create_transition(action=action_string)
result = processor(transition)
assert result[TransitionKey.ACTION] == action_string
@@ -689,7 +703,7 @@ def test_complementary_data_none():
transition = create_transition(complementary_data=None)
result = processor(transition)
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
assert result[TransitionKey.COMPLEMENTARY_DATA] is None
def test_complementary_data_empty():
+23 -21
View File
@@ -31,7 +31,19 @@ from lerobot.processor import (
NormalizerProcessorStep,
TransitionKey,
)
from lerobot.processor.converters import create_transition
def create_transition(observation=None, action=None, **kwargs):
"""Helper function to create a transition dictionary."""
transition = {}
if observation is not None:
transition[TransitionKey.OBSERVATION] = observation
if action is not None:
transition[TransitionKey.ACTION] = action
for key, value in kwargs.items():
if hasattr(TransitionKey, key.upper()):
transition[getattr(TransitionKey, key.upper())] = value
return transition
def create_default_config():
@@ -103,8 +115,7 @@ def test_classifier_processor_normalization():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(1) # Dummy action/reward
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -135,8 +146,7 @@ def test_classifier_processor_cuda():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(1)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -147,8 +157,7 @@ def test_classifier_processor_cuda():
assert processed[TransitionKey.ACTION].device.type == "cuda"
# Process through postprocessor
reward_transition = create_transition()
reward_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
reward_transition = create_transition(action=processed[TransitionKey.ACTION])
postprocessed = postprocessor(reward_transition)
# Check that output is back on CPU
@@ -176,8 +185,7 @@ def test_classifier_processor_accelerate_scenario():
OBS_IMAGE: torch.randn(3, 224, 224).to(device),
}
action = torch.randn(1).to(device)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -204,8 +212,7 @@ def test_classifier_processor_multi_gpu():
OBS_IMAGE: torch.randn(3, 224, 224).to(device),
}
action = torch.randn(1).to(device)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -232,8 +239,7 @@ def test_classifier_processor_without_stats():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(1)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
processed = preprocessor(transition)
assert processed is not None
@@ -267,8 +273,7 @@ def test_classifier_processor_save_and_load():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(1)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
processed = loaded_preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (10,)
@@ -303,8 +308,7 @@ def test_classifier_processor_mixed_precision():
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
}
action = torch.randn(1, dtype=torch.float32)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -334,8 +338,7 @@ def test_classifier_processor_batch_data():
OBS_IMAGE: torch.randn(batch_size, 3, 224, 224),
}
action = torch.randn(batch_size, 1)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -360,8 +363,7 @@ def test_classifier_processor_postprocessor_identity():
# Create test data for postprocessor
reward = torch.tensor([[0.8], [0.3], [0.9]]) # Batch of rewards/predictions
transition = create_transition()
transition[TransitionKey.ACTION] = reward
transition = create_transition(action=reward)
# Process through postprocessor
processed = postprocessor(transition)
+26 -5
View File
@@ -20,7 +20,28 @@ import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey
from lerobot.processor.converters import create_transition
def create_transition(
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
"""Helper function to create a transition dictionary."""
transition = {}
if observation is not None:
transition[TransitionKey.OBSERVATION] = observation
if action is not None:
transition[TransitionKey.ACTION] = action
if reward is not None:
transition[TransitionKey.REWARD] = reward
if done is not None:
transition[TransitionKey.DONE] = done
if truncated is not None:
transition[TransitionKey.TRUNCATED] = truncated
if info is not None:
transition[TransitionKey.INFO] = info
if complementary_data is not None:
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return transition
def test_basic_functionality():
@@ -126,14 +147,14 @@ def test_none_values():
# Test with None observation
transition = create_transition(observation=None, action=torch.randn(5))
result = processor(transition)
assert result[TransitionKey.OBSERVATION] is None
assert TransitionKey.OBSERVATION not in result
assert result[TransitionKey.ACTION].device.type == "cpu"
# Test with None action
transition = create_transition(observation={"observation.state": torch.randn(10)}, action=None)
result = processor(transition)
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
assert result[TransitionKey.ACTION] is None
assert TransitionKey.ACTION not in result
def test_empty_observation():
@@ -801,8 +822,8 @@ def test_complementary_data_none():
result = processor(transition)
# Complementary data should be an empty dict (standardized behavior)
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
# Complementary data should not be in the result (same as input)
assert TransitionKey.COMPLEMENTARY_DATA not in result
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+24 -23
View File
@@ -33,7 +33,19 @@ from lerobot.processor import (
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition
def create_transition(observation=None, action=None, **kwargs):
"""Helper function to create a transition dictionary."""
transition = {}
if observation is not None:
transition[TransitionKey.OBSERVATION] = observation
if action is not None:
transition[TransitionKey.ACTION] = action
for key, value in kwargs.items():
if hasattr(TransitionKey, key.upper()):
transition[getattr(TransitionKey, key.upper())] = value
return transition
def create_default_config():
@@ -106,8 +118,7 @@ def test_diffusion_processor_with_images():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(6)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -138,8 +149,7 @@ def test_diffusion_processor_cuda():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(6)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -150,8 +160,7 @@ def test_diffusion_processor_cuda():
assert processed[TransitionKey.ACTION].device.type == "cuda"
# Process through postprocessor
action_transition = create_transition()
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
action_transition = create_transition(action=processed[TransitionKey.ACTION])
postprocessed = postprocessor(action_transition)
# Check that action is back on CPU
@@ -179,8 +188,7 @@ def test_diffusion_processor_accelerate_scenario():
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 6).to(device)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -207,8 +215,7 @@ def test_diffusion_processor_multi_gpu():
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 6).to(device)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -235,8 +242,7 @@ def test_diffusion_processor_without_stats():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(6)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
processed = preprocessor(transition)
assert processed is not None
@@ -270,8 +276,7 @@ def test_diffusion_processor_save_and_load():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(6)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
processed = loaded_preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
@@ -317,8 +322,7 @@ def test_diffusion_processor_mixed_precision():
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
}
action = torch.randn(6, dtype=torch.float32)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -348,8 +352,7 @@ def test_diffusion_processor_identity_normalization():
OBS_IMAGE: image_value.clone(),
}
action = torch.randn(6)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -378,8 +381,7 @@ def test_diffusion_processor_batch_consistency():
OBS_IMAGE: torch.randn(batch_size, 3, 224, 224) if batch_size > 1 else torch.randn(3, 224, 224),
}
action = torch.randn(batch_size, 6) if batch_size > 1 else torch.randn(6)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
processed = preprocessor(transition)
@@ -433,8 +435,7 @@ def test_diffusion_processor_bfloat16_device_float32_normalizer():
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
}
action = torch.randn(6, dtype=torch.float32)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through full pipeline
processed = preprocessor(transition)
+19 -9
View File
@@ -34,7 +34,6 @@ from lerobot.processor import (
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition
class MockTokenizerProcessorStep(ProcessorStep):
@@ -53,6 +52,21 @@ class MockTokenizerProcessorStep(ProcessorStep):
return features
def create_transition(observation=None, action=None, **kwargs):
"""Helper function to create a transition dictionary."""
transition = {}
if observation is not None:
transition[TransitionKey.OBSERVATION] = observation
if action is not None:
transition[TransitionKey.ACTION] = action
for key, value in kwargs.items():
if hasattr(TransitionKey, key.upper()):
transition[getattr(TransitionKey, key.upper())] = value
elif key == "complementary_data":
transition[TransitionKey.COMPLEMENTARY_DATA] = value
return transition
def create_default_config():
"""Create a default PI0 configuration for testing."""
config = PI0Config()
@@ -205,8 +219,7 @@ def test_pi0_processor_cuda():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(6)
transition = create_transition(observation=observation, complementary_data={"task": "test task"})
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action, complementary_data={"task": "test task"})
# Process through preprocessor
processed = preprocessor(transition)
@@ -262,8 +275,7 @@ def test_pi0_processor_accelerate_scenario():
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 6).to(device)
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
# Process through preprocessor
processed = preprocessor(transition)
@@ -319,8 +331,7 @@ def test_pi0_processor_multi_gpu():
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 6).to(device)
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
# Process through preprocessor
processed = preprocessor(transition)
@@ -415,9 +426,8 @@ def test_pi0_processor_bfloat16_device_float32_normalizer():
}
action = torch.randn(6, dtype=torch.float32) # PI0 expects size 6
transition = create_transition(
observation=observation, complementary_data={"task": "test bfloat16 adaptation"}
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
)
transition[TransitionKey.ACTION] = action
# Process through full pipeline
processed = preprocessor(transition)
+26 -26
View File
@@ -33,7 +33,19 @@ from lerobot.processor import (
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition
def create_transition(observation=None, action=None, **kwargs):
"""Helper function to create a transition dictionary."""
transition = {}
if observation is not None:
transition[TransitionKey.OBSERVATION] = observation
if action is not None:
transition[TransitionKey.ACTION] = action
for key, value in kwargs.items():
if hasattr(TransitionKey, key.upper()):
transition[getattr(TransitionKey, key.upper())] = value
return transition
def create_default_config():
@@ -105,8 +117,7 @@ def test_sac_processor_normalization_modes():
# Create test data
observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization
action = torch.rand(5) * 2 - 1 # Range [-1, 1]
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -118,8 +129,7 @@ def test_sac_processor_normalization_modes():
assert processed[TransitionKey.ACTION].shape == (1, 5)
# Process action through postprocessor
action_transition = create_transition()
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
action_transition = create_transition(action=processed[TransitionKey.ACTION])
postprocessed = postprocessor(action_transition)
# Check that action is unnormalized (but still batched)
@@ -143,8 +153,7 @@ def test_sac_processor_cuda():
# Create CPU data
observation = {OBS_STATE: torch.randn(10)}
action = torch.randn(5)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -154,8 +163,7 @@ def test_sac_processor_cuda():
assert processed[TransitionKey.ACTION].device.type == "cuda"
# Process through postprocessor
action_transition = create_transition()
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
action_transition = create_transition(action=processed[TransitionKey.ACTION])
postprocessed = postprocessor(action_transition)
# Check that action is back on CPU
@@ -180,8 +188,7 @@ def test_sac_processor_accelerate_scenario():
device = torch.device("cuda:0")
observation = {OBS_STATE: torch.randn(10).to(device)}
action = torch.randn(5).to(device)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -209,8 +216,7 @@ def test_sac_processor_multi_gpu():
device = torch.device("cuda:1")
observation = {OBS_STATE: torch.randn(10).to(device)}
action = torch.randn(5).to(device)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -248,8 +254,7 @@ def test_sac_processor_without_stats():
# Process should still work
observation = {OBS_STATE: torch.randn(10)}
action = torch.randn(5)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
processed = preprocessor(transition)
assert processed is not None
@@ -279,8 +284,7 @@ def test_sac_processor_save_and_load():
# Test that loaded processor works
observation = {OBS_STATE: torch.randn(10)}
action = torch.randn(5)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
processed = loaded_preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
@@ -325,8 +329,7 @@ def test_sac_processor_mixed_precision():
# Create test data
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)}
action = torch.randn(5, dtype=torch.float32)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -352,8 +355,7 @@ def test_sac_processor_batch_data():
batch_size = 32
observation = {OBS_STATE: torch.randn(batch_size, 10)}
action = torch.randn(batch_size, 5)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -376,14 +378,13 @@ def test_sac_processor_edge_cases():
)
# Test with empty observation
transition = create_transition(observation={})
transition[TransitionKey.ACTION] = torch.randn(5)
transition = create_transition(observation={}, action=torch.randn(5))
processed = preprocessor(transition)
assert processed[TransitionKey.OBSERVATION] == {}
assert processed[TransitionKey.ACTION].shape == (1, 5)
# Test with None action
transition = create_transition(observation={OBS_STATE: torch.randn(10)})
transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=None)
processed = preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
# When action is None, it may still be present with None value
@@ -432,8 +433,7 @@ def test_sac_processor_bfloat16_device_float32_normalizer():
# Create test data
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} # Start with float32
action = torch.randn(5, dtype=torch.float32)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through full pipeline
processed = preprocessor(transition)
+19 -9
View File
@@ -37,7 +37,6 @@ from lerobot.processor import (
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition
class MockTokenizerProcessorStep(ProcessorStep):
@@ -56,6 +55,21 @@ class MockTokenizerProcessorStep(ProcessorStep):
return features
def create_transition(observation=None, action=None, **kwargs):
"""Helper function to create a transition dictionary."""
transition = {}
if observation is not None:
transition[TransitionKey.OBSERVATION] = observation
if action is not None:
transition[TransitionKey.ACTION] = action
for key, value in kwargs.items():
if hasattr(TransitionKey, key.upper()):
transition[getattr(TransitionKey, key.upper())] = value
elif key == "complementary_data":
transition[TransitionKey.COMPLEMENTARY_DATA] = value
return transition
def create_default_config():
"""Create a default SmolVLA configuration for testing."""
config = SmolVLAConfig()
@@ -214,8 +228,7 @@ def test_smolvla_processor_cuda():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(7)
transition = create_transition(observation=observation, complementary_data={"task": "test task"})
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action, complementary_data={"task": "test task"})
# Process through preprocessor
processed = preprocessor(transition)
@@ -273,8 +286,7 @@ def test_smolvla_processor_accelerate_scenario():
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 7).to(device)
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
# Process through preprocessor
processed = preprocessor(transition)
@@ -332,8 +344,7 @@ def test_smolvla_processor_multi_gpu():
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 7).to(device)
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
# Process through preprocessor
processed = preprocessor(transition)
@@ -444,9 +455,8 @@ def test_smolvla_processor_bfloat16_device_float32_normalizer():
}
action = torch.randn(7, dtype=torch.float32)
transition = create_transition(
observation=observation, complementary_data={"task": "test bfloat16 adaptation"}
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
)
transition[TransitionKey.ACTION] = action
# Process through full pipeline
processed = preprocessor(transition)
+24 -23
View File
@@ -33,7 +33,19 @@ from lerobot.processor import (
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition
def create_transition(observation=None, action=None, **kwargs):
"""Helper function to create a transition dictionary."""
transition = {}
if observation is not None:
transition[TransitionKey.OBSERVATION] = observation
if action is not None:
transition[TransitionKey.ACTION] = action
for key, value in kwargs.items():
if hasattr(TransitionKey, key.upper()):
transition[getattr(TransitionKey, key.upper())] = value
return transition
def create_default_config():
@@ -111,8 +123,7 @@ def test_vqbet_processor_with_images():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(7)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -143,8 +154,7 @@ def test_vqbet_processor_cuda():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(7)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -155,8 +165,7 @@ def test_vqbet_processor_cuda():
assert processed[TransitionKey.ACTION].device.type == "cuda"
# Process through postprocessor
action_transition = create_transition()
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
action_transition = create_transition(action=processed[TransitionKey.ACTION])
postprocessed = postprocessor(action_transition)
# Check that action is back on CPU
@@ -184,8 +193,7 @@ def test_vqbet_processor_accelerate_scenario():
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 7).to(device)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -217,8 +225,7 @@ def test_vqbet_processor_multi_gpu():
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 7).to(device)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -260,8 +267,7 @@ def test_vqbet_processor_without_stats():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(7)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
processed = preprocessor(transition)
assert processed is not None
@@ -294,8 +300,7 @@ def test_vqbet_processor_save_and_load():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(7)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
processed = loaded_preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 8)
@@ -344,8 +349,7 @@ def test_vqbet_processor_mixed_precision():
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
}
action = torch.randn(7, dtype=torch.float32)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -375,8 +379,7 @@ def test_vqbet_processor_large_batch():
OBS_IMAGE: torch.randn(batch_size, 3, 224, 224),
}
action = torch.randn(batch_size, 7)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through preprocessor
processed = preprocessor(transition)
@@ -407,8 +410,7 @@ def test_vqbet_processor_sequential_processing():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(7)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
processed = preprocessor(transition)
results.append(processed)
@@ -465,8 +467,7 @@ def test_vqbet_processor_bfloat16_device_float32_normalizer():
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
}
action = torch.randn(7, dtype=torch.float32)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
transition = create_transition(observation, action)
# Process through full pipeline
processed = preprocessor(transition)