refactor(pipeline): Transition from tuple to dictionary format for EnvTransition

- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
This commit is contained in:
Adil Zouitine
2025-07-21 14:54:31 +02:00
parent 14c2ece004
commit f2b79656eb
16 changed files with 828 additions and 650 deletions
+257 -175
View File
@@ -112,24 +112,23 @@ RobotProcessor solves these issues by providing a declarative pipeline approach
RobotProcessor works with two data formats:
### 1. EnvTransition Tuple Format
### 1. EnvTransition Dictionary Format
An `EnvTransition` is a 7-tuple that represents a complete transition in the environment:
An `EnvTransition` is a dictionary that represents a complete transition in the environment:
```python
from lerobot.processor.pipeline import TransitionIndex
from lerobot.processor.pipeline import TransitionKey
# EnvTransition structure:
# (observation, action, reward, done, truncated, info, complementary_data)
transition = (
{"observation.image": ..., "observation.state": ...}, # observation at time t
[0.1, -0.2, 0.3], # action taken at time t
1.0, # reward received
False, # episode done flag
False, # episode truncated flag
{"success": True}, # additional info from environment
{"step_idx": 42} # complementary_data for inter-step communication
)
transition = {
TransitionKey.OBSERVATION: {"observation.image": ..., "observation.state": ...}, # observation at time t
TransitionKey.ACTION: [0.1, -0.2, 0.3], # action taken at time t
TransitionKey.REWARD: 1.0, # reward received
TransitionKey.DONE: False, # episode done flag
TransitionKey.TRUNCATED: False, # episode truncated flag
TransitionKey.INFO: {"success": True}, # additional info from environment
TransitionKey.COMPLEMENTARY_DATA: {"step_idx": 42} # complementary_data for inter-step communication
}
```
### 2. Batch Dictionary Format
@@ -160,9 +159,17 @@ from lerobot.processor.observation_processor import ImageProcessor
processor = RobotProcessor([ImageProcessor()])
# Works with EnvTransition tuples
transition = ({"pixels": image_array}, None, 0.0, False, False, {}, {})
processed_transition = processor(transition) # Returns EnvTransition tuple
# Works with EnvTransition dictionaries
transition = {
TransitionKey.OBSERVATION: {"pixels": image_array},
TransitionKey.ACTION: None,
TransitionKey.REWARD: 0.0,
TransitionKey.DONE: False,
TransitionKey.TRUNCATED: False,
TransitionKey.INFO: {},
TransitionKey.COMPLEMENTARY_DATA: {}
}
processed_transition = processor(transition) # Returns EnvTransition dictionary
# Also works with batch dictionaries
batch = {
@@ -176,25 +183,25 @@ batch = {
processed_batch = processor(batch) # Returns batch dictionary
```
### Using TransitionIndex
### Using TransitionKey
Instead of using magic numbers to access tuple elements, use the `TransitionIndex` enum:
Use the `TransitionKey` enum to access dictionary elements:
```python
from lerobot.processor.pipeline import TransitionIndex
from lerobot.processor.pipeline import TransitionKey
# Bad - using magic numbers
obs = transition[0]
action = transition[1]
# Good - using TransitionKey
obs = transition[TransitionKey.OBSERVATION]
action = transition[TransitionKey.ACTION]
reward = transition[TransitionKey.REWARD]
done = transition[TransitionKey.DONE]
truncated = transition[TransitionKey.TRUNCATED]
info = transition[TransitionKey.INFO]
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
# Good - using TransitionIndex
obs = transition[TransitionIndex.OBSERVATION]
action = transition[TransitionIndex.ACTION]
reward = transition[TransitionIndex.REWARD]
done = transition[TransitionIndex.DONE]
truncated = transition[TransitionIndex.TRUNCATED]
info = transition[TransitionIndex.INFO]
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
# Alternative - using .get() for optional access
obs = transition.get(TransitionKey.OBSERVATION)
action = transition.get(TransitionKey.ACTION)
```
### Default Conversion Functions
@@ -203,43 +210,49 @@ RobotProcessor uses these default conversion functions:
```python
def _default_batch_to_transition(batch):
"""Default conversion from batch dict to EnvTransition tuple."""
"""Default conversion from batch dict to EnvTransition dictionary."""
# Extract observation keys (anything starting with "observation.")
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
observation = observation_keys if observation_keys else None
observation = None
if observation_keys:
observation = {}
# Keep observation.* keys as-is (don't remove "observation." prefix)
for key, value in observation_keys.items():
observation[key] = value
# Extract padding and task keys for complementary data
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}
return (
observation,
batch.get("action"),
batch.get("next.reward", 0.0), # Note: "next.reward" not "reward"
batch.get("next.done", False), # Note: "next.done" not "done"
batch.get("next.truncated", False), # Note: "next.truncated" not "truncated"
batch.get("info", {}),
{}, # Empty complementary_data
)
transition = {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: batch.get("action"),
TransitionKey.REWARD: batch.get("next.reward", 0.0),
TransitionKey.DONE: batch.get("next.done", False),
TransitionKey.TRUNCATED: batch.get("next.truncated", False),
TransitionKey.INFO: batch.get("info", {}),
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
}
return transition
def _default_transition_to_batch(transition):
"""Default conversion from EnvTransition tuple to batch dict."""
obs, action, reward, done, truncated, info, _ = transition
"""Default conversion from EnvTransition dictionary to batch dict."""
batch = {
"action": action,
"next.reward": reward, # Note: "next.reward" not "reward"
"next.done": done, # Note: "next.done" not "done"
"next.truncated": truncated, # Note: "next.truncated" not "truncated"
"info": info,
"action": transition.get(TransitionKey.ACTION),
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
"next.done": transition.get(TransitionKey.DONE, False),
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
"info": transition.get(TransitionKey.INFO, {}),
}
# Flatten observation dict (keep observation.* keys as-is)
if isinstance(obs, dict):
for key, value in obs.items():
batch[key] = value
# Add padding and task data from complementary_data
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data:
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
batch.update(pad_data)
if "task" in complementary_data:
batch["task"] = complementary_data["task"]
# Handle observation - flatten dict to observation.* keys if it's a dict
observation = transition.get(TransitionKey.OBSERVATION)
if isinstance(observation, dict):
batch.update(observation)
return batch
```
@@ -250,33 +263,32 @@ You can customize how RobotProcessor converts between formats:
```python
def custom_batch_to_transition(batch):
"""Custom conversion from batch dict to EnvTransition tuple."""
"""Custom conversion from batch dict to EnvTransition dictionary."""
# Extract observation keys (anything starting with "observation.")
observation = {k: v for k, v in batch.items() if k.startswith("observation.")}
return (
observation,
batch.get("action"),
batch.get("reward", 0.0), # Use "reward" instead of "next.reward"
batch.get("done", False), # Use "done" instead of "next.done"
batch.get("truncated", False),
batch.get("info", {}),
batch.get("complementary_data", {})
)
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: batch.get("action"),
TransitionKey.REWARD: batch.get("reward", 0.0), # Use "reward" instead of "next.reward"
TransitionKey.DONE: batch.get("done", False), # Use "done" instead of "next.done"
TransitionKey.TRUNCATED: batch.get("truncated", False),
TransitionKey.INFO: batch.get("info", {}),
TransitionKey.COMPLEMENTARY_DATA: batch.get("complementary_data", {})
}
def custom_transition_to_batch(transition):
"""Custom conversion from EnvTransition tuple to batch dict."""
obs, action, reward, done, truncated, info, comp_data = transition
"""Custom conversion from EnvTransition dictionary to batch dict."""
batch = {
"action": action,
"reward": reward, # Use "reward" instead of "next.reward"
"done": done, # Use "done" instead of "next.done"
"truncated": truncated,
"info": info,
"action": transition.get(TransitionKey.ACTION),
"reward": transition.get(TransitionKey.REWARD), # Use "reward" instead of "next.reward"
"done": transition.get(TransitionKey.DONE), # Use "done" instead of "next.done"
"truncated": transition.get(TransitionKey.TRUNCATED),
"info": transition.get(TransitionKey.INFO),
}
# Flatten observation dict
obs = transition.get(TransitionKey.OBSERVATION)
if obs:
batch.update(obs)
@@ -292,21 +304,21 @@ processor = RobotProcessor(
### Advanced: Controlling Output Format with `to_output`
The `to_output` function determines what format is returned when you call the processor with a batch dictionary. Sometimes you want to output `EnvTransition` tuples even when you input batch dictionaries:
The `to_output` function determines what format is returned when you call the processor with a batch dictionary. Sometimes you want to output `EnvTransition` dictionaries even when you input batch dictionaries:
```python
# Identity function to always return EnvTransition tuples
# Identity function to always return EnvTransition dictionaries
def keep_as_transition(transition):
"""Always return EnvTransition tuple regardless of input format."""
"""Always return EnvTransition dictionary regardless of input format."""
return transition
# Processor that always outputs EnvTransition tuples
# Processor that always outputs EnvTransition dictionaries
processor = RobotProcessor(
steps=[ImageProcessor(), StateProcessor()],
to_output=keep_as_transition # Always return tuple format
to_output=keep_as_transition # Always return dictionary format
)
# Even when called with batch dict, returns EnvTransition tuple
# Even when called with batch dict, returns EnvTransition dictionary
batch = {
"observation.image": image_tensor,
"action": action_tensor,
@@ -316,13 +328,13 @@ batch = {
"info": info_dict
}
result = processor(batch) # Returns EnvTransition tuple, not batch dict!
print(type(result)) # <class 'tuple'>
result = processor(batch) # Returns EnvTransition dictionary, not batch dict!
print(type(result)) # <class 'dict'>
```
### Real-World Example: Environment Interaction
This is particularly useful for environment interaction where you want consistent tuple output:
This is particularly useful for environment interaction where you want consistent dictionary output:
```python
from lerobot.processor.observation_processor import VanillaObservationProcessor
@@ -332,7 +344,7 @@ from lerobot.processor.observation_processor import VanillaObservationProcessor
env_processor = RobotProcessor(
[VanillaObservationProcessor()],
to_transition=lambda x: x, # Pass through - no conversion needed
to_output=lambda x: x, # Always return EnvTransition tuple
to_output=lambda x: x, # Always return EnvTransition dictionary
)
# Environment interaction loop
@@ -340,12 +352,20 @@ env = make_env()
obs, info = env.reset()
for step in range(1000):
# Create transition - input is already in tuple format
transition = (obs, None, 0.0, False, False, info, {"step": step})
# Create transition - input is already in dictionary format
transition = {
TransitionKey.OBSERVATION: obs,
TransitionKey.ACTION: None,
TransitionKey.REWARD: 0.0,
TransitionKey.DONE: False,
TransitionKey.TRUNCATED: False,
TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: {"step": step}
}
# Process - output is guaranteed to be EnvTransition tuple
# Process - output is guaranteed to be EnvTransition dictionary
processed_transition = env_processor(transition)
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
processed_obs = processed_transition[TransitionKey.OBSERVATION]
# Use with policy
action = policy.select_action(processed_obs)
@@ -357,12 +377,12 @@ for step in range(1000):
### When to Use Different Output Formats
**Use EnvTransition tuple output when:**
**Use EnvTransition dictionary output when:**
- Environment interaction and real-time control
- You need to access individual transition components frequently
- Performance is critical (avoids dictionary creation overhead)
- Working with gym environments that expect tuple format
- Working with gym environments that expect structured data
- You need the flexibility of dictionary operations
**Use batch dictionary output when:**
@@ -372,10 +392,10 @@ for step in range(1000):
- You need the standardized "next.\*" key format
```python
# For environment interaction - use tuple output
# For environment interaction - use dictionary output
env_processor = RobotProcessor(
steps=[ImageProcessor(), StateProcessor()],
to_output=lambda x: x # Return EnvTransition tuple
to_output=lambda x: x # Return EnvTransition dictionary
)
# For training - use batch output (default)
@@ -391,9 +411,17 @@ for batch in dataloader:
# Environment loop
for step in range(1000):
transition = (obs, None, 0.0, False, False, info, {})
processed_transition = env_processor(transition) # Returns EnvTransition tuple
obs = processed_transition[TransitionIndex.OBSERVATION]
transition = {
TransitionKey.OBSERVATION: obs,
TransitionKey.ACTION: None,
TransitionKey.REWARD: 0.0,
TransitionKey.DONE: False,
TransitionKey.TRUNCATED: False,
TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: {}
}
processed_transition = env_processor(transition) # Returns EnvTransition dictionary
obs = processed_transition[TransitionKey.OBSERVATION]
action = policy.select_action(obs)
```
@@ -426,7 +454,7 @@ batch = {
Let's create a processor that properly handles image and state preprocessing:
```python
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
from lerobot.processor.observation_processor import ImageProcessor, StateProcessor
import numpy as np
@@ -440,7 +468,15 @@ observation = {
}
# Create a full transition
transition = (observation, None, 0.0, False, False, {}, {})
transition = {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: None,
TransitionKey.REWARD: 0.0,
TransitionKey.DONE: False,
TransitionKey.TRUNCATED: False,
TransitionKey.INFO: {},
TransitionKey.COMPLEMENTARY_DATA: {}
}
# Create and use the processor
processor = RobotProcessor([
@@ -449,7 +485,7 @@ processor = RobotProcessor([
])
processed_transition = processor(transition)
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
processed_obs = processed_transition[TransitionKey.OBSERVATION]
# Check the results
print("Original keys:", observation.keys())
@@ -541,11 +577,19 @@ obs, info = env.reset()
for step in range(1000):
# Raw environment observation
transition = (obs, None, 0.0, False, False, info, {})
transition = {
TransitionKey.OBSERVATION: obs,
TransitionKey.ACTION: None,
TransitionKey.REWARD: 0.0,
TransitionKey.DONE: False,
TransitionKey.TRUNCATED: False,
TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: {}
}
# Process for policy input
processed_transition = online_processor(transition)
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
processed_obs = processed_transition[TransitionKey.OBSERVATION]
# Get action from policy
action = policy.select_action(processed_obs)
@@ -585,15 +629,16 @@ class ImagePadder:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Main processing method - required for all steps."""
obs = transition[TransitionIndex.OBSERVATION]
obs = transition.get(TransitionKey.OBSERVATION)
if obs is None:
return transition
# Process all image observations
for key in list(obs.keys()):
processed_obs = dict(obs) # Create a copy
for key in list(processed_obs.keys()):
if key.startswith("observation.images."):
img = obs[key]
img = processed_obs[key]
# Calculate padding
_, _, h, w = img.shape
pad_h = max(0, self.target_height - h)
@@ -609,10 +654,12 @@ class ImagePadder:
img = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom),
mode='constant', value=self.pad_value)
obs[key] = img
processed_obs[key] = img
# Return modified transition
return (obs, *transition[1:])
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
def get_config(self) -> Dict[str, Any]:
"""Return JSON-serializable configuration - required for save/load."""
@@ -694,8 +741,8 @@ class ImageStatisticsCalculator:
"""Calculate image statistics and pass to next steps."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {}
obs = transition.get(TransitionKey.OBSERVATION)
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if obs is None:
return transition
@@ -714,18 +761,13 @@ class ImageStatisticsCalculator:
image_stats[key] = stats
# Store in complementary_data for next steps
comp_data = dict(comp_data) # Make a copy
comp_data["image_statistics"] = image_stats
# Return transition with updated complementary_data
return (
obs,
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
comp_data # Updated complementary_data
)
new_transition = transition.copy()
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
@dataclass
class AdaptiveBrightnessAdjuster:
@@ -734,8 +776,8 @@ class AdaptiveBrightnessAdjuster:
target_brightness: float = 0.5
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {}
obs = transition.get(TransitionKey.OBSERVATION)
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if obs is None or "image_statistics" not in comp_data:
return transition
@@ -743,15 +785,18 @@ class AdaptiveBrightnessAdjuster:
# Use statistics from previous step
image_stats = comp_data["image_statistics"]
for key in obs:
processed_obs = dict(obs) # Create a copy
for key in processed_obs:
if key.startswith("observation.images.") and key in image_stats:
current_mean = image_stats[key]["mean"]
brightness_adjust = self.target_brightness - current_mean
# Adjust brightness
obs[key] = torch.clamp(obs[key] + brightness_adjust, 0, 1)
processed_obs[key] = torch.clamp(processed_obs[key] + brightness_adjust, 0, 1)
return (obs, *transition[1:])
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
# Use them together
processor = RobotProcessor([
@@ -782,7 +827,7 @@ class ActionRepeatStep:
env: gym.Env = None # This can't be serialized to JSON!
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs, action, reward, done, truncated, info, comp_data = transition
action = transition.get(TransitionKey.ACTION)
if self.env is not None and action is not None:
# Repeat action multiple times in environment
@@ -792,9 +837,13 @@ class ActionRepeatStep:
total_reward += r
if d or t:
break
reward = total_reward
return (obs, action, reward, done, truncated, info, comp_data)
# Update reward in transition
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = total_reward
return new_transition
return transition
def get_config(self) -> dict[str, Any]:
# Note: env is NOT included because it's not serializable
@@ -1211,7 +1260,7 @@ This enables sharing of preprocessing logic while allowing each user to provide
Here's a complete example showing proper device management and all features:
```python
from lerobot.processor.pipeline import RobotProcessor, ProcessorStepRegistry, TransitionIndex
from lerobot.processor.pipeline import RobotProcessor, ProcessorStepRegistry, TransitionKey
import torch
import torch.nn.functional as F
import numpy as np
@@ -1224,23 +1273,29 @@ class DeviceMover:
device: str = "cuda"
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
obs = transition.get(TransitionKey.OBSERVATION)
if obs is None:
return transition
# Move all tensor observations to device
for key, value in obs.items():
processed_obs = dict(obs) # Create a copy
for key, value in processed_obs.items():
if isinstance(value, torch.Tensor):
obs[key] = value.to(self.device)
processed_obs[key] = value.to(self.device)
# Also handle action if present
action = transition[TransitionIndex.ACTION]
action = transition.get(TransitionKey.ACTION)
if action is not None and isinstance(action, torch.Tensor):
action = action.to(self.device)
return (obs, action, *transition[2:])
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
new_transition[TransitionKey.ACTION] = action
return new_transition
return (obs, *transition[1:])
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
def get_config(self) -> Dict[str, Any]:
return {"device": str(self.device)}
@@ -1260,7 +1315,7 @@ class RunningNormalizer:
self.initialized = False
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
obs = transition.get(TransitionKey.OBSERVATION)
if obs is None or "observation.state" not in obs:
return transition
@@ -1284,9 +1339,12 @@ class RunningNormalizer:
# Normalize
state_normalized = (state - self.running_mean) / (self.running_var + 1e-8).sqrt()
obs["observation.state"] = state_normalized
processed_obs = dict(obs) # Create a copy
processed_obs["observation.state"] = state_normalized
return (obs, *transition[1:])
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
def get_config(self) -> Dict[str, Any]:
return {
@@ -1317,7 +1375,7 @@ class RunningNormalizer:
processor = RobotProcessor([
ImageProcessor(), # Convert images to float32 [0,1]
StateProcessor(), # Convert states to torch tensors
ImagePadder(224, 224), # Pad images to standard size
ImagePadder(target_height=224, target_width=224), # Pad images to standard size
DeviceMover("cuda"), # Move everything to GPU
RunningNormalizer(7), # Normalize states
], name="CompletePreprocessor")
@@ -1330,11 +1388,19 @@ obs = {
"pixels": {"cam": np.random.randint(0, 255, (200, 300, 3), dtype=np.uint8)},
"agent_pos": np.random.randn(7).astype(np.float32)
}
transition = (obs, None, 0.0, False, False, {}, {})
transition = {
TransitionKey.OBSERVATION: obs,
TransitionKey.ACTION: None,
TransitionKey.REWARD: 0.0,
TransitionKey.DONE: False,
TransitionKey.TRUNCATED: False,
TransitionKey.INFO: {},
TransitionKey.COMPLEMENTARY_DATA: {}
}
# Everything is processed and on GPU
processed = processor(transition)
print(processed[TransitionIndex.OBSERVATION]["observation.images.cam"].device) # cuda:0
print(processed[TransitionKey.OBSERVATION]["observation.images.cam"].device) # cuda:0
```
## Solving Real-World Problems with RobotProcessor
@@ -1358,22 +1424,25 @@ class KeyRemapper:
})
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
obs = transition.get(TransitionKey.OBSERVATION)
if obs is None:
return transition
# Create new observation with renamed keys
processed_obs = dict(obs) # Create a copy
renamed_obs = {}
for old_key, new_key in self.key_mapping.items():
if old_key in obs:
renamed_obs[new_key] = obs[old_key]
if old_key in processed_obs:
renamed_obs[new_key] = processed_obs[old_key]
# Keep any unmapped keys as-is
for key, value in obs.items():
for key, value in processed_obs.items():
if key not in self.key_mapping:
renamed_obs[key] = value
return (renamed_obs, *transition[1:])
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = renamed_obs
return new_transition
```
### Workspace-Focused Image Processing
@@ -1390,13 +1459,14 @@ class WorkspaceCropper:
output_size: Tuple[int, int] = (224, 224)
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
obs = transition.get(TransitionKey.OBSERVATION)
if obs is None:
return transition
for key in list(obs.keys()):
processed_obs = dict(obs) # Create a copy
for key in list(processed_obs.keys()):
if key.startswith("observation.images."):
img = obs[key]
img = processed_obs[key]
# Crop to workspace
x1, y1, x2, y2 = self.crop_bbox
img_cropped = img[:, :, y1:y2, x1:x2]
@@ -1407,9 +1477,11 @@ class WorkspaceCropper:
mode='bilinear',
align_corners=False
)
obs[key] = img_resized
processed_obs[key] = img_resized
return (obs, *transition[1:])
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
```
### Building Complete Pipelines for Different Robots
@@ -1471,7 +1543,7 @@ The beauty of this approach is that:
```python
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
obs = transition.get(TransitionKey.OBSERVATION)
# Always check if observation exists
if obs is None:
@@ -1496,7 +1568,7 @@ return (modified_obs, None, 0.0, False, False, {}, {})
```python
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
obs = transition.get(TransitionKey.OBSERVATION)
if self.store_previous:
# Good - clone to avoid reference issues
@@ -1522,7 +1594,7 @@ def state_dict(self) -> Dict[str, torch.Tensor]:
Here's how to use RobotProcessor in a real robot control loop, showing both dictionary and batch formats:
```python
from lerobot.processor.pipeline import RobotProcessor, ProcessorStepRegistry, TransitionIndex
from lerobot.processor.pipeline import RobotProcessor, ProcessorStepRegistry, TransitionKey
from lerobot.policies.act.modeling_act import ACTPolicy
from pathlib import Path
import time
@@ -1545,11 +1617,13 @@ class ActionClipper:
max_value: float = 1.0
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition[TransitionIndex.ACTION]
action = transition.get(TransitionKey.ACTION)
if action is not None:
action = torch.clamp(action, self.min_value, self.max_value)
return (transition[TransitionIndex.OBSERVATION], action, *transition[2:])
new_transition = transition.copy()
new_transition[TransitionKey.ACTION] = action
return new_transition
return transition
@@ -1578,28 +1652,36 @@ for episode in range(10):
for step in range(1000):
# Create transition with raw observation
transition = (obs, None, 0.0, False, False, info, {"step": step})
transition = {
TransitionKey.OBSERVATION: obs,
TransitionKey.ACTION: None,
TransitionKey.REWARD: 0.0,
TransitionKey.DONE: False,
TransitionKey.TRUNCATED: False,
TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: {"step": step}
}
# Preprocess - works with tuple format
# Preprocess - works with dictionary format
processed_transition = preprocessor(transition)
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
processed_obs = processed_transition.get(TransitionKey.OBSERVATION)
# Get action from policy
with torch.no_grad():
action = policy.select_action(processed_obs)
# Postprocess action
action_transition = (
processed_obs,
action,
0.0,
False,
False,
info,
{"raw_action": action.clone()} # Store raw action in complementary_data
)
action_transition = {
TransitionKey.OBSERVATION: processed_obs,
TransitionKey.ACTION: action,
TransitionKey.REWARD: 0.0,
TransitionKey.DONE: False,
TransitionKey.TRUNCATED: False,
TransitionKey.INFO: info,
TransitionKey.COMPLEMENTARY_DATA: {"raw_action": action.clone()} # Store raw action in complementary_data
}
processed_action_transition = postprocessor(action_transition)
final_action = processed_action_transition[TransitionIndex.ACTION]
final_action = processed_action_transition.get(TransitionKey.ACTION)
# Execute action
obs, reward, terminated, truncated, info = env.step(final_action.cpu().numpy())
@@ -1667,7 +1749,7 @@ Use the full power of `RobotProcessor` for debugging:
```python
# Enable detailed logging
def log_observation_shapes(step_idx: int, transition: EnvTransition):
obs = transition[TransitionIndex.OBSERVATION]
obs = transition.get(TransitionKey.OBSERVATION)
if obs:
print(f"Step {step_idx} observations:")
for key, value in obs.items():
@@ -1679,7 +1761,7 @@ processor.register_after_step_hook(log_observation_shapes)
# Monitor complementary data flow
def monitor_complementary_data(step_idx: int, transition: EnvTransition):
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if comp_data:
print(f"Step {step_idx} complementary_data: {list(comp_data.keys())}")
return None
@@ -1688,7 +1770,7 @@ processor.register_before_step_hook(monitor_complementary_data)
# Validate data integrity
def validate_tensors(step_idx: int, transition: EnvTransition):
obs = transition[TransitionIndex.OBSERVATION]
obs = transition.get(TransitionKey.OBSERVATION)
if obs:
for key, value in obs.items():
if isinstance(value, torch.Tensor):
@@ -1705,21 +1787,21 @@ processor.register_after_step_hook(validate_tensors)
RobotProcessor provides a powerful, modular approach to data preprocessing in robotics:
- **Dual format support**: Works seamlessly with both EnvTransition tuples and batch dictionaries
- **Automatic format conversion**: Converts between tuple and batch formats as needed
- **Dual format support**: Works seamlessly with both EnvTransition dictionaries and batch dictionaries
- **Automatic format conversion**: Converts between dictionary and batch formats as needed
- **LeRobot integration**: Native support for LeRobotDataset and ReplayBuffer formats
- **Clear separation of concerns**: Each transformation is a separate, testable unit
- **Proper state management**: Clear distinction between config (JSON) and state (tensors)
- **Device-aware**: Seamless GPU/CPU transfers with `.to(device)`
- **Inter-step communication**: Use `complementary_data` for passing information
- **Easy sharing**: Push to Hugging Face Hub for reproducibility
- **Type safety**: Use `TransitionIndex` instead of magic numbers
- **Type safety**: Use `TransitionKey` enum members instead of ad-hoc dictionary keys
- **Debugging tools**: Step through transformations and add monitoring hooks
- **Flexible conversion**: Customize `to_transition` and `to_output` functions for specific needs
Key advantages of the dual format approach:
- **Environment interaction**: Use tuple format for real-time robot control
- **Environment interaction**: Use dictionary format for real-time robot control
- **Training/evaluation**: Use batch format for dataset processing and model training
- **Seamless integration**: Same processor works with both formats automatically
- **Backward compatibility**: Existing code using either format continues to work
+14 -6
View File
@@ -36,17 +36,25 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
"""
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor
# Create processor with observation processor
processor = RobotProcessor([VanillaObservationProcessor()])
# Create transition tuple and process
transition = (observations, None, None, None, None, None, None)
processed_transition = processor(transition)
# Create transition dictionary and process
transition = {
TransitionKey.OBSERVATION: observations,
TransitionKey.ACTION: None,
TransitionKey.REWARD: None,
TransitionKey.DONE: None,
TransitionKey.TRUNCATED: None,
TransitionKey.INFO: None,
TransitionKey.COMPLEMENTARY_DATA: None,
}
result = processor(transition)
# Return processed observations
return processed_transition[TransitionIndex.OBSERVATION]
# Extract and return the processed observation
return result[TransitionKey.OBSERVATION]
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
+2 -2
View File
@@ -32,7 +32,7 @@ from .pipeline import (
ProcessorStepRegistry,
RewardProcessor,
RobotProcessor,
TransitionIndex,
TransitionKey,
TruncatedProcessor,
)
from .rename_processor import RenameProcessor
@@ -54,7 +54,7 @@ __all__ = [
"RewardProcessor",
"RobotProcessor",
"StateProcessor",
"TransitionIndex",
"TransitionKey",
"TruncatedProcessor",
"VanillaObservationProcessor",
]
+32 -21
View File
@@ -18,7 +18,7 @@ from typing import Any
import torch
from lerobot.processor.pipeline import EnvTransition, TransitionIndex
from lerobot.processor.pipeline import EnvTransition, TransitionKey
@dataclass
@@ -35,30 +35,41 @@ class DeviceProcessor:
self.non_blocking = "cuda" in self.device
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation: dict[str, torch.Tensor] = transition[TransitionIndex.OBSERVATION]
action = transition[TransitionIndex.ACTION]
reward = transition[TransitionIndex.REWARD]
done = transition[TransitionIndex.DONE]
truncated = transition[TransitionIndex.TRUNCATED]
info = transition[TransitionIndex.INFO]
complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
# Create a copy of the transition
new_transition = transition.copy()
# Process observation tensors
observation = transition.get(TransitionKey.OBSERVATION)
if observation is not None:
observation = {
k: v.to(self.device, non_blocking=self.non_blocking) for k, v in observation.items()
new_observation = {
k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
for k, v in observation.items()
}
if action is not None:
action = action.to(self.device)
new_transition[TransitionKey.OBSERVATION] = new_observation
return (
observation,
action,
reward,
done,
truncated,
info,
complementary_data,
)
# Process action tensor
action = transition.get(TransitionKey.ACTION)
if action is not None and isinstance(action, torch.Tensor):
new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)
# Process reward tensor
reward = transition.get(TransitionKey.REWARD)
if reward is not None and isinstance(reward, torch.Tensor):
new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)
# Process done tensor
done = transition.get(TransitionKey.DONE)
if done is not None and isinstance(done, torch.Tensor):
new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)
# Process truncated tensor
truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is not None and isinstance(truncated, torch.Tensor):
new_transition[TransitionKey.TRUNCATED] = truncated.to(
self.device, non_blocking=self.non_blocking
)
return new_transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
+18 -24
View File
@@ -1,8 +1,8 @@
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any
from collections.abc import Mapping
import numpy as np
import torch
@@ -10,7 +10,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
@@ -166,17 +166,14 @@ class NormalizerProcessor:
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = self._normalize_obs(transition[TransitionIndex.OBSERVATION])
action = self._normalize_action(transition[TransitionIndex.ACTION])
return (
observation,
action,
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._normalize_action(transition.get(TransitionKey.ACTION))
# Create a new transition with normalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
return new_transition
def get_config(self) -> dict[str, Any]:
config = {
@@ -297,17 +294,14 @@ class UnnormalizerProcessor:
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = self._unnormalize_obs(transition[TransitionIndex.OBSERVATION])
action = self._unnormalize_action(transition[TransitionIndex.ACTION])
return (
observation,
action,
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._unnormalize_action(transition.get(TransitionKey.ACTION))
# Create a new transition with unnormalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
return new_transition
def get_config(self) -> dict[str, Any]:
return {
+9 -21
View File
@@ -21,7 +21,7 @@ import numpy as np
import torch
from torch import Tensor
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@dataclass
@@ -36,7 +36,7 @@ class ImageProcessor:
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
@@ -60,15 +60,9 @@ class ImageProcessor:
processed_obs[key] = value
# Return new transition with processed observation
return (
processed_obs,
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
def _process_single_image(self, img: np.ndarray) -> Tensor:
"""Process a single image array."""
@@ -124,7 +118,7 @@ class StateProcessor:
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
@@ -150,15 +144,9 @@ class StateProcessor:
del processed_obs["agent_pos"]
# Return new transition with processed observation
return (
processed_obs,
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
+132 -152
View File
@@ -18,39 +18,42 @@ from __future__ import annotations
import importlib
import json
import os
from dataclasses import dataclass, field
from enum import IntEnum
from pathlib import Path
from typing import Any, Protocol
from collections.abc import Callable, Iterable, Sequence
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Protocol, TypedDict
import torch
from huggingface_hub import ModelHubMixin, hf_hub_download
from safetensors.torch import load_file, save_file
class TransitionIndex(IntEnum):
"""Explicit indices for EnvTransition tuple components."""
class TransitionKey(str, Enum):
"""Keys for accessing EnvTransition dictionary components."""
OBSERVATION = 0
ACTION = 1
REWARD = 2
DONE = 3
TRUNCATED = 4
INFO = 5
COMPLEMENTARY_DATA = 6
OBSERVATION = "observation"
ACTION = "action"
REWARD = "reward"
DONE = "done"
TRUNCATED = "truncated"
INFO = "info"
COMPLEMENTARY_DATA = "complementary_data"
# (observation, action, reward, done, truncated, info, complementary_data)
EnvTransition = tuple[
dict[str, Any] | None, # observation
Any | torch.Tensor | None, # action
float | torch.Tensor | None, # reward
bool | torch.Tensor | None, # done
bool | torch.Tensor | None, # truncated
dict[str, Any] | None, # info
dict[str, Any] | None, # complementary_data
]
class EnvTransition(TypedDict, total=False):
    """Environment transition data structure.

    All fields are optional (total=False) to allow flexible usage.

    The field names below are the same strings that the ``TransitionKey``
    members (a ``str``-based ``Enum``) carry as their values, so every
    ``TransitionKey`` member corresponds to exactly one field here.
    """

    # Observation at time t, as a mapping of observation names to values.
    observation: dict[str, Any] | None
    # Action taken at time t.
    action: Any | torch.Tensor | None
    # Reward received.
    reward: float | torch.Tensor | None
    # Episode done flag.
    done: bool | torch.Tensor | None
    # Episode truncated flag.
    truncated: bool | torch.Tensor | None
    # Additional info from the environment.
    info: dict[str, Any] | None
    # Side-channel data used for inter-step communication.
    complementary_data: dict[str, Any] | None
class ProcessorStepRegistry:
@@ -165,10 +168,9 @@ class ProcessorStep(Protocol):
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
"""Convert a *batch* dict coming from LeRobot replay/dataset code into an
``EnvTransition`` tuple.
``EnvTransition`` dictionary.
The function is intentionally **strictly positional** it maps well known
keys to the fixed slot order used inside the pipeline. Missing keys are
The function maps well-known keys to the EnvTransition structure. Missing keys are
filled with sane defaults (``None`` or ``0.0``/``False``).
Keys recognised (case-sensitive):
@@ -193,15 +195,16 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq
task_key = {"task": batch["task"]} if "task" in batch else {}
complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}
return (
observation,
batch.get("action"),
batch.get("next.reward", 0.0),
batch.get("next.done", False),
batch.get("next.truncated", False),
batch.get("info", {}),
complementary_data,
)
transition: EnvTransition = {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: batch.get("action"),
TransitionKey.REWARD: batch.get("next.reward", 0.0),
TransitionKey.DONE: batch.get("next.done", False),
TransitionKey.TRUNCATED: batch.get("next.truncated", False),
TransitionKey.INFO: batch.get("info", {}),
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
}
return transition
def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: # noqa: D401
@@ -209,25 +212,16 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
the canonical field names used throughout *LeRobot*.
"""
(
observation,
action,
reward,
done,
truncated,
info,
complementary_data,
) = transition
batch = {
"action": action,
"next.reward": reward,
"next.done": done,
"next.truncated": truncated,
"info": info,
"action": transition.get(TransitionKey.ACTION),
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
"next.done": transition.get(TransitionKey.DONE, False),
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
"info": transition.get(TransitionKey.INFO, {}),
}
# Add padding and task data from complementary_data
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data:
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
batch.update(pad_data)
@@ -236,6 +230,7 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
batch["task"] = complementary_data["task"]
# Handle observation - flatten dict to observation.* keys if it's a dict
observation = transition.get(TransitionKey.OBSERVATION)
if isinstance(observation, dict):
batch.update(observation)
@@ -293,33 +288,35 @@ class RobotProcessor(ModelHubMixin):
def __call__(self, data: EnvTransition | dict[str, Any]):
"""Process data through all steps.
The method accepts either the classic EnvTransition tuple or a batch dictionary
The method accepts either the classic EnvTransition dict or a batch dictionary
(like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied
it is first converted to the internal tuple format using to_transition; after all
steps are executed the tuple is transformed back into a dict with to_batch and the
it is first converted to the internal dict format using to_transition; after all
steps are executed the dict is transformed back into a batch dict with to_batch and the
result is returned, thereby preserving the caller's original data type.
Args:
data: Either an EnvTransition tuple or a batch dictionary to process.
data: Either an EnvTransition dict or a batch dictionary to process.
Returns:
The processed data in the same format as the input (tuple or dict).
The processed data in the same format as the input (EnvTransition or batch dict).
Raises:
ValueError: If the transition is not a valid 7-tuple format.
ValueError: If the transition is not a valid EnvTransition format.
"""
called_with_batch = isinstance(data, dict)
# Check if data is already an EnvTransition or needs conversion
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
# It's a batch dict, convert it
called_with_batch = True
transition = self.to_transition(data)
else:
# It's already an EnvTransition
called_with_batch = False
transition = data
transition = self.to_transition(data) if called_with_batch else data
# Basic validation with helpful error message for tuple input
if not isinstance(transition, tuple) or len(transition) != 7:
raise ValueError(
"EnvTransition must be a 7-tuple of (observation, action, reward, done, "
"truncated, info, complementary_data). "
f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
)
# Basic validation
if not isinstance(transition, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
for idx, processor_step in enumerate(self.steps):
for hook in self.before_step_hooks:
@@ -339,25 +336,28 @@ class RobotProcessor(ModelHubMixin):
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
"""Yield the intermediate results after each processor step.
Like __call__, this method accepts either EnvTransition tuples or batch dictionaries
Like __call__, this method accepts either EnvTransition dicts or batch dictionaries
and preserves the input format in the yielded results.
Args:
data: Either an EnvTransition tuple or a batch dictionary to process.
data: Either an EnvTransition dict or a batch dictionary to process.
Yields:
The intermediate results after each step, in the same format as the input.
"""
called_with_batch = isinstance(data, dict)
transition = self.to_transition(data) if called_with_batch else data
# Check if data is already an EnvTransition or needs conversion
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
# It's a batch dict, convert it
called_with_batch = True
transition = self.to_transition(data)
else:
# It's already an EnvTransition
called_with_batch = False
transition = data
# Basic validation with helpful error message for tuple input
if not isinstance(transition, tuple) or len(transition) != 7:
raise ValueError(
"EnvTransition must be a 7-tuple of (observation, action, reward, done, "
"truncated, info, complementary_data). "
f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
)
# Basic validation
if not isinstance(transition, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
# Yield initial state
yield self.to_output(transition) if called_with_batch else transition
@@ -684,7 +684,7 @@ class ObservationProcessor:
Subclasses should override the `observation` method to implement custom observation processing.
This class handles the boilerplate of extracting and reinserting the processed observation
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
@@ -696,7 +696,7 @@ class ObservationProcessor:
return observation * self.scale_factor
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
By inheriting from this class, you avoid writing repetitive code to handle transition dict
manipulation, focusing only on the specific observation processing logic.
"""
@@ -712,10 +712,12 @@ class ObservationProcessor:
return observation
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
observation = self.observation(observation)
transition = (observation, *transition[TransitionIndex.ACTION :])
return transition
observation = transition.get(TransitionKey.OBSERVATION)
processed_observation = self.observation(observation)
# Create a new transition dict with the processed observation
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_observation
return new_transition
class ActionProcessor:
@@ -723,7 +725,7 @@ class ActionProcessor:
Subclasses should override the `action` method to implement custom action processing.
This class handles the boilerplate of extracting and reinserting the processed action
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
@@ -736,7 +738,7 @@ class ActionProcessor:
return np.clip(action, self.min_val, self.max_val)
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
By inheriting from this class, you avoid writing repetitive code to handle transition dict
manipulation, focusing only on the specific action processing logic.
"""
@@ -752,10 +754,12 @@ class ActionProcessor:
return action
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition[TransitionIndex.ACTION]
action = self.action(action)
transition = (transition[TransitionIndex.OBSERVATION], action, *transition[TransitionIndex.REWARD :])
return transition
action = transition.get(TransitionKey.ACTION)
processed_action = self.action(action)
# Create a new transition dict with the processed action
new_transition = transition.copy()
new_transition[TransitionKey.ACTION] = processed_action
return new_transition
class RewardProcessor:
@@ -763,7 +767,7 @@ class RewardProcessor:
Subclasses should override the `reward` method to implement custom reward processing.
This class handles the boilerplate of extracting and reinserting the processed reward
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
@@ -775,7 +779,7 @@ class RewardProcessor:
return reward * self.scale_factor
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
By inheriting from this class, you avoid writing repetitive code to handle transition dict
manipulation, focusing only on the specific reward processing logic.
"""
@@ -791,15 +795,12 @@ class RewardProcessor:
return reward
def __call__(self, transition: EnvTransition) -> EnvTransition:
reward = transition[TransitionIndex.REWARD]
reward = self.reward(reward)
transition = (
transition[TransitionIndex.OBSERVATION],
transition[TransitionIndex.ACTION],
reward,
*transition[TransitionIndex.DONE :],
)
return transition
reward = transition.get(TransitionKey.REWARD)
processed_reward = self.reward(reward)
# Create a new transition dict with the processed reward
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = processed_reward
return new_transition
class DoneProcessor:
@@ -807,7 +808,7 @@ class DoneProcessor:
Subclasses should override the `done` method to implement custom done flag processing.
This class handles the boilerplate of extracting and reinserting the processed done flag
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
@@ -824,7 +825,7 @@ class DoneProcessor:
self.steps = 0
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
By inheriting from this class, you avoid writing repetitive code to handle transition dict
manipulation, focusing only on the specific done flag processing logic.
"""
@@ -840,16 +841,12 @@ class DoneProcessor:
return done
def __call__(self, transition: EnvTransition) -> EnvTransition:
done = transition[TransitionIndex.DONE]
done = self.done(done)
transition = (
transition[TransitionIndex.OBSERVATION],
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
done,
*transition[TransitionIndex.TRUNCATED :],
)
return transition
done = transition.get(TransitionKey.DONE)
processed_done = self.done(done)
# Create a new transition dict with the processed done flag
new_transition = transition.copy()
new_transition[TransitionKey.DONE] = processed_done
return new_transition
class TruncatedProcessor:
@@ -857,7 +854,7 @@ class TruncatedProcessor:
Subclasses should override the `truncated` method to implement custom truncated flag processing.
This class handles the boilerplate of extracting and reinserting the processed truncated flag
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
@@ -870,7 +867,7 @@ class TruncatedProcessor:
return truncated or some_condition > self.threshold
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
By inheriting from this class, you avoid writing repetitive code to handle transition dict
manipulation, focusing only on the specific truncated flag processing logic.
"""
@@ -886,17 +883,12 @@ class TruncatedProcessor:
return truncated
def __call__(self, transition: EnvTransition) -> EnvTransition:
truncated = transition[TransitionIndex.TRUNCATED]
truncated = self.truncated(truncated)
transition = (
transition[TransitionIndex.OBSERVATION],
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
truncated,
*transition[TransitionIndex.INFO :],
)
return transition
truncated = transition.get(TransitionKey.TRUNCATED)
processed_truncated = self.truncated(truncated)
# Create a new transition dict with the processed truncated flag
new_transition = transition.copy()
new_transition[TransitionKey.TRUNCATED] = processed_truncated
return new_transition
class InfoProcessor:
@@ -904,7 +896,7 @@ class InfoProcessor:
Subclasses should override the `info` method to implement custom info processing.
This class handles the boilerplate of extracting and reinserting the processed info
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
@@ -922,7 +914,7 @@ class InfoProcessor:
self.step_count = 0
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
By inheriting from this class, you avoid writing repetitive code to handle transition dict
manipulation, focusing only on the specific info dictionary processing logic.
"""
@@ -938,18 +930,12 @@ class InfoProcessor:
return info
def __call__(self, transition: EnvTransition) -> EnvTransition:
info = transition[TransitionIndex.INFO]
info = self.info(info)
transition = (
transition[TransitionIndex.OBSERVATION],
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
info,
*transition[TransitionIndex.COMPLEMENTARY_DATA :],
)
return transition
info = transition.get(TransitionKey.INFO)
processed_info = self.info(info)
# Create a new transition dict with the processed info
new_transition = transition.copy()
new_transition[TransitionKey.INFO] = processed_info
return new_transition
class ComplementaryDataProcessor:
@@ -957,7 +943,7 @@ class ComplementaryDataProcessor:
Subclasses should override the `complementary_data` method to implement custom complementary data processing.
This class handles the boilerplate of extracting and reinserting the processed complementary data
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
"""
def complementary_data(self, complementary_data):
@@ -972,18 +958,12 @@ class ComplementaryDataProcessor:
return complementary_data
def __call__(self, transition: EnvTransition) -> EnvTransition:
complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
complementary_data = self.complementary_data(complementary_data)
transition = (
transition[TransitionIndex.OBSERVATION],
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
complementary_data,
)
return transition
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
processed_complementary_data = self.complementary_data(complementary_data)
# Create a new transition dict with the processed complementary data
new_transition = transition.copy()
new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data
return new_transition
class IdentityProcessor:
+7 -11
View File
@@ -18,7 +18,7 @@ from typing import Any
import torch
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@dataclass
@@ -29,7 +29,7 @@ class RenameProcessor:
rename_map: dict[str, str] = field(default_factory=dict)
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
@@ -39,15 +39,11 @@ class RenameProcessor:
processed_obs[self.rename_map[key]] = value
else:
processed_obs[key] = value
return (
processed_obs,
transition[TransitionIndex.ACTION],
transition[TransitionIndex.REWARD],
transition[TransitionIndex.DONE],
transition[TransitionIndex.TRUNCATED],
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
# Create a new transition with the renamed observation
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
def get_config(self) -> dict[str, Any]:
return {"rename_map": self.rename_map}
+3 -3
View File
@@ -72,7 +72,7 @@ from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
@@ -160,7 +160,7 @@ def rollout(
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
# RobotProcessor only accepts dictionaries now: a 7-tuple fails its
# "EnvTransition must be a dictionary" validation, and the result is
# indexed with TransitionKey below. Build the dictionary form explicitly.
transition = {
    TransitionKey.OBSERVATION: observation,
    TransitionKey.ACTION: None,
    TransitionKey.REWARD: None,
    TransitionKey.DONE: None,
    TransitionKey.TRUNCATED: None,
    TransitionKey.INFO: None,
    TransitionKey.COMPLEMENTARY_DATA: None,
}
processed_transition = obs_processor(transition)
observation = processed_transition[TransitionIndex.OBSERVATION]
observation = processed_transition[TransitionKey.OBSERVATION]
if return_observations:
all_observations.append(deepcopy(observation))
@@ -211,7 +211,7 @@ def rollout(
if return_observations:
# The final observation must also be wrapped in a TransitionKey-keyed
# dictionary; RobotProcessor rejects the former 7-tuple with ValueError.
transition = {
    TransitionKey.OBSERVATION: observation,
    TransitionKey.ACTION: None,
    TransitionKey.REWARD: None,
    TransitionKey.DONE: None,
    TransitionKey.TRUNCATED: None,
    TransitionKey.INFO: None,
    TransitionKey.COMPLEMENTARY_DATA: None,
}
processed_transition = obs_processor(transition)
observation = processed_transition[TransitionIndex.OBSERVATION]
observation = processed_transition[TransitionKey.OBSERVATION]
all_observations.append(deepcopy(observation))
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
+2 -2
View File
@@ -22,7 +22,7 @@ from gymnasium.utils.env_checker import check_env
import lerobot
from lerobot.envs.factory import make_env, make_env_config
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor
from tests.utils import require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
@@ -53,7 +53,7 @@ def test_factory(env_name):
obs_processor = RobotProcessor([VanillaObservationProcessor()])
# RobotProcessor validates that the transition is a dictionary, so the
# former 7-tuple would raise ValueError before any step runs.
transition = {
    TransitionKey.OBSERVATION: obs,
    TransitionKey.ACTION: None,
    TransitionKey.REWARD: None,
    TransitionKey.DONE: None,
    TransitionKey.TRUNCATED: None,
    TransitionKey.INFO: None,
    TransitionKey.COMPLEMENTARY_DATA: None,
}
processed_transition = obs_processor(transition)
obs = processed_transition[TransitionIndex.OBSERVATION]
obs = processed_transition[TransitionKey.OBSERVATION]
# test image keys are float32 in range [0,1]
for key in obs:
+2 -2
View File
@@ -39,7 +39,7 @@ from lerobot.policies.factory import (
)
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor
from lerobot.utils.random_utils import seeded_context
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -188,7 +188,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
obs_processor = RobotProcessor([VanillaObservationProcessor()])
# RobotProcessor validates that the transition is a dictionary, so the
# former 7-tuple would raise ValueError before any step runs.
transition = {
    TransitionKey.OBSERVATION: observation,
    TransitionKey.ACTION: None,
    TransitionKey.REWARD: None,
    TransitionKey.DONE: None,
    TransitionKey.TRUNCATED: None,
    TransitionKey.INFO: None,
    TransitionKey.COMPLEMENTARY_DATA: None,
}
processed_transition = obs_processor(transition)
observation = processed_transition[TransitionIndex.OBSERVATION]
observation = processed_transition[TransitionKey.OBSERVATION]
# send observation to device/gpu
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
+59 -65
View File
@@ -2,7 +2,7 @@ import torch
from lerobot.processor.pipeline import (
RobotProcessor,
TransitionIndex,
TransitionKey,
_default_batch_to_transition,
_default_transition_to_batch,
)
@@ -63,27 +63,27 @@ def test_batch_to_transition_observation_grouping():
transition = _default_batch_to_transition(batch)
# Check observation is a dict with all observation.* keys
assert isinstance(transition[TransitionIndex.OBSERVATION], dict)
assert "observation.image.top" in transition[TransitionIndex.OBSERVATION]
assert "observation.image.left" in transition[TransitionIndex.OBSERVATION]
assert "observation.state" in transition[TransitionIndex.OBSERVATION]
assert isinstance(transition[TransitionKey.OBSERVATION], dict)
assert "observation.image.top" in transition[TransitionKey.OBSERVATION]
assert "observation.image.left" in transition[TransitionKey.OBSERVATION]
assert "observation.state" in transition[TransitionKey.OBSERVATION]
# Check values are preserved
assert torch.allclose(
transition[TransitionIndex.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
)
assert torch.allclose(
transition[TransitionIndex.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
)
assert transition[TransitionIndex.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
# Check other fields
assert transition[TransitionIndex.ACTION] == "action_data"
assert transition[TransitionIndex.REWARD] == 1.5
assert transition[TransitionIndex.DONE]
assert not transition[TransitionIndex.TRUNCATED]
assert transition[TransitionIndex.INFO] == {"episode": 42}
assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
assert transition[TransitionKey.ACTION] == "action_data"
assert transition[TransitionKey.REWARD] == 1.5
assert transition[TransitionKey.DONE]
assert not transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {"episode": 42}
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
def test_transition_to_batch_observation_flattening():
@@ -94,15 +94,15 @@ def test_transition_to_batch_observation_flattening():
"observation.state": [1, 2, 3, 4],
}
transition = (
observation_dict, # observation
"action_data", # action
1.5, # reward
True, # done
False, # truncated
{"episode": 42}, # info
{}, # complementary_data
)
transition = {
TransitionKey.OBSERVATION: observation_dict,
TransitionKey.ACTION: "action_data",
TransitionKey.REWARD: 1.5,
TransitionKey.DONE: True,
TransitionKey.TRUNCATED: False,
TransitionKey.INFO: {"episode": 42},
TransitionKey.COMPLEMENTARY_DATA: {},
}
batch = _default_transition_to_batch(transition)
@@ -137,14 +137,14 @@ def test_no_observation_keys():
transition = _default_batch_to_transition(batch)
# Observation should be None when no observation.* keys
assert transition[TransitionIndex.OBSERVATION] is None
assert transition[TransitionKey.OBSERVATION] is None
# Check other fields
assert transition[TransitionIndex.ACTION] == "action_data"
assert transition[TransitionIndex.REWARD] == 2.0
assert not transition[TransitionIndex.DONE]
assert transition[TransitionIndex.TRUNCATED]
assert transition[TransitionIndex.INFO] == {"test": "no_obs"}
assert transition[TransitionKey.ACTION] == "action_data"
assert transition[TransitionKey.REWARD] == 2.0
assert not transition[TransitionKey.DONE]
assert transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {"test": "no_obs"}
# Round trip should work
reconstructed_batch = _default_transition_to_batch(transition)
@@ -162,15 +162,15 @@ def test_minimal_batch():
transition = _default_batch_to_transition(batch)
# Check observation
assert transition[TransitionIndex.OBSERVATION] == {"observation.state": "minimal_state"}
assert transition[TransitionIndex.ACTION] == "minimal_action"
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
assert transition[TransitionKey.ACTION] == "minimal_action"
# Check defaults
assert transition[TransitionIndex.REWARD] == 0.0
assert not transition[TransitionIndex.DONE]
assert not transition[TransitionIndex.TRUNCATED]
assert transition[TransitionIndex.INFO] == {}
assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
assert transition[TransitionKey.REWARD] == 0.0
assert not transition[TransitionKey.DONE]
assert not transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {}
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
# Round trip
reconstructed_batch = _default_transition_to_batch(transition)
@@ -189,13 +189,13 @@ def test_empty_batch():
transition = _default_batch_to_transition(batch)
# All fields should have defaults
assert transition[TransitionIndex.OBSERVATION] is None
assert transition[TransitionIndex.ACTION] is None
assert transition[TransitionIndex.REWARD] == 0.0
assert not transition[TransitionIndex.DONE]
assert not transition[TransitionIndex.TRUNCATED]
assert transition[TransitionIndex.INFO] == {}
assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
assert transition[TransitionKey.OBSERVATION] is None
assert transition[TransitionKey.ACTION] is None
assert transition[TransitionKey.REWARD] == 0.0
assert not transition[TransitionKey.DONE]
assert not transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {}
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
# Round trip
reconstructed_batch = _default_transition_to_batch(transition)
@@ -256,33 +256,27 @@ def test_custom_converter():
# Custom converter that modifies the reward
tr = _default_batch_to_transition(batch)
# Double the reward
reward = tr[TransitionIndex.REWARD] * 2 if tr[TransitionIndex.REWARD] is not None else 0.0
return (
tr[TransitionIndex.OBSERVATION],
tr[TransitionIndex.ACTION],
reward,
tr[TransitionIndex.DONE],
tr[TransitionIndex.TRUNCATED],
tr[TransitionIndex.INFO],
tr[TransitionIndex.COMPLEMENTARY_DATA],
)
reward = tr.get(TransitionKey.REWARD, 0.0)
new_tr = tr.copy()
new_tr[TransitionKey.REWARD] = reward * 2 if reward is not None else 0.0
return new_tr
def to_batch(tr):
# Custom converter that adds a custom field
batch = _default_transition_to_batch(tr)
batch["custom_field"] = "custom_value"
return batch
proc = RobotProcessor([], to_transition=to_tr, to_output=to_batch)
batch = _dummy_batch()
out = proc(batch)
processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)
# Check that custom modifications were applied
assert out["next.reward"] == batch["next.reward"] * 2
assert out["custom_field"] == "custom_value"
batch = {
"observation.state": torch.randn(1, 4),
"action": torch.randn(1, 2),
"next.reward": 1.0,
"next.done": False,
}
# Check that observation.* keys are still preserved
original_obs_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
output_obs_keys = {k: v for k, v in out.items() if k.startswith("observation.")}
result = processor(batch)
assert set(original_obs_keys.keys()) == set(output_obs_keys.keys())
# Check the reward was doubled by our custom converter
assert result["next.reward"] == 2.0
assert torch.allclose(result["observation.state"], batch["observation.state"])
assert torch.allclose(result["action"], batch["action"])
+94 -35
View File
@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
import numpy as np
@@ -10,7 +25,22 @@ from lerobot.processor.normalize_processor import (
UnnormalizerProcessor,
_convert_stats_to_tensors,
)
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
def create_transition(
    observation=None,
    action=None,
    reward=None,
    done=None,
    truncated=None,
    info=None,
    complementary_data=None,
):
    """Assemble an EnvTransition dictionary from its individual components.

    Each argument is stored unchanged under the matching ``TransitionKey``
    member. Arguments that are not supplied are stored as ``None``, so the
    returned dictionary always contains all seven keys.
    """
    # Keys and values are listed in the same order so that zip pairs each
    # component with its TransitionKey.
    keys = (
        TransitionKey.OBSERVATION,
        TransitionKey.ACTION,
        TransitionKey.REWARD,
        TransitionKey.DONE,
        TransitionKey.TRUNCATED,
        TransitionKey.INFO,
        TransitionKey.COMPLEMENTARY_DATA,
    )
    values = (observation, action, reward, done, truncated, info, complementary_data)
    return dict(zip(keys, values))
def test_numpy_conversion():
@@ -120,10 +150,10 @@ def test_mean_std_normalization(observation_normalizer):
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
normalized_transition = observation_normalizer(transition)
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Check mean/std normalization
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
@@ -134,10 +164,10 @@ def test_min_max_normalization(observation_normalizer):
observation = {
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
normalized_transition = observation_normalizer(transition)
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Check min/max normalization to [-1, 1]
# For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0
@@ -157,10 +187,10 @@ def test_selective_normalization(observation_stats):
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Only image should be normalized
assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2)
@@ -176,10 +206,10 @@ def test_device_compatibility(observation_stats):
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
assert normalized_obs["observation.image"].device.type == "cuda"
@@ -220,10 +250,10 @@ def test_state_dict_save_load(observation_normalizer):
# Test that it works the same
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result1 = observation_normalizer(transition)[0]
result2 = new_normalizer(transition)[0]
result1 = observation_normalizer(transition)[TransitionKey.OBSERVATION]
result2 = new_normalizer(transition)[TransitionKey.OBSERVATION]
assert torch.allclose(result1["observation.image"], result2["observation.image"])
@@ -271,10 +301,10 @@ def test_mean_std_unnormalization(action_stats_mean_std):
)
normalized_action = torch.tensor([1.0, -0.5, 2.0])
transition = (None, normalized_action, None, None, None, None, None)
transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition)
unnormalized_action = unnormalized_transition[TransitionIndex.ACTION]
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
# action * std + mean
expected = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0, 2.0 * 0.5 + 0.0])
@@ -290,10 +320,10 @@ def test_min_max_unnormalization(action_stats_min_max):
# Actions in [-1, 1]
normalized_action = torch.tensor([0.0, -1.0, 1.0])
transition = (None, normalized_action, None, None, None, None, None)
transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition)
unnormalized_action = unnormalized_transition[TransitionIndex.ACTION]
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
# Map from [-1, 1] to [min, max]
# (action + 1) / 2 * (max - min) + min
@@ -315,10 +345,10 @@ def test_numpy_action_input(action_stats_mean_std):
)
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
transition = (None, normalized_action, None, None, None, None, None)
transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition)
unnormalized_action = unnormalized_transition[TransitionIndex.ACTION]
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
assert isinstance(unnormalized_action, torch.Tensor)
expected = torch.tensor([1.0, -1.0, 1.0])
@@ -332,7 +362,7 @@ def test_none_action(action_stats_mean_std):
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
)
transition = (None, None, None, None, None, None, None)
transition = create_transition()
result = unnormalizer(transition)
# Should return transition unchanged
@@ -396,23 +426,31 @@ def test_combined_normalization(normalizer_processor):
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = (observation, action, 1.0, False, False, {}, {})
transition = create_transition(
observation=observation,
action=action,
reward=1.0,
done=False,
truncated=False,
info={},
complementary_data={},
)
processed_transition = normalizer_processor(transition)
# Check normalized observations
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
processed_obs = processed_transition[TransitionKey.OBSERVATION]
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
assert torch.allclose(processed_obs["observation.image"], expected_image)
# Check normalized action
processed_action = processed_transition[TransitionIndex.ACTION]
processed_action = processed_transition[TransitionKey.ACTION]
expected_action = torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0])
assert torch.allclose(processed_action, expected_action)
# Check other fields remain unchanged
assert processed_transition[TransitionIndex.REWARD] == 1.0
assert not processed_transition[TransitionIndex.DONE]
assert processed_transition[TransitionKey.REWARD] == 1.0
assert not processed_transition[TransitionKey.DONE]
def test_processor_from_lerobot_dataset(full_stats):
@@ -466,13 +504,21 @@ def test_integration_with_robot_processor(normalizer_processor):
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = (observation, action, 1.0, False, False, {}, {})
transition = create_transition(
observation=observation,
action=action,
reward=1.0,
done=False,
truncated=False,
info={},
complementary_data={},
)
processed_transition = robot_processor(transition)
# Verify the processing worked
assert isinstance(processed_transition[TransitionIndex.OBSERVATION], dict)
assert isinstance(processed_transition[TransitionIndex.ACTION], torch.Tensor)
assert isinstance(processed_transition[TransitionKey.OBSERVATION], dict)
assert isinstance(processed_transition[TransitionKey.ACTION], torch.Tensor)
# Edge case tests
@@ -482,7 +528,7 @@ def test_empty_observation():
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
transition = (None, None, None, None, None, None, None)
transition = create_transition()
result = normalizer(transition)
assert result == transition
@@ -493,11 +539,13 @@ def test_empty_stats():
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
observation = {"observation.image": torch.tensor([0.5])}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = normalizer(transition)
# Should return observation unchanged since no stats are available
assert torch.allclose(result[0]["observation.image"], observation["observation.image"])
assert torch.allclose(
result[TransitionKey.OBSERVATION]["observation.image"], observation["observation.image"]
)
def test_partial_stats():
@@ -507,9 +555,9 @@ def test_partial_stats():
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {"observation.image": torch.tensor([0.7])}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
processed = normalizer(transition)[TransitionIndex.OBSERVATION]
processed = normalizer(transition)[TransitionKey.OBSERVATION]
assert torch.allclose(processed["observation.image"], observation["observation.image"])
@@ -551,14 +599,25 @@ def test_serialization_roundtrip(full_stats):
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = (observation, action, 1.0, False, False, {}, {})
transition = create_transition(
observation=observation,
action=action,
reward=1.0,
done=False,
truncated=False,
info={},
complementary_data={},
)
result1 = original_processor(transition)
result2 = new_processor(transition)
# Compare results
assert torch.allclose(result1[0]["observation.image"], result2[0]["observation.image"])
assert torch.allclose(result1[1], result2[1])
assert torch.allclose(
result1[TransitionKey.OBSERVATION]["observation.image"],
result2[TransitionKey.OBSERVATION]["observation.image"],
)
assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION])
# Verify features and norm_map are correctly reconstructed
assert new_processor.features.keys() == original_processor.features.keys()
+49 -33
View File
@@ -23,6 +23,22 @@ from lerobot.processor import (
StateProcessor,
VanillaObservationProcessor,
)
from lerobot.processor.pipeline import TransitionKey
def create_transition(
    observation=None,
    action=None,
    reward=None,
    done=None,
    truncated=None,
    info=None,
    complementary_data=None,
):
    """Build an EnvTransition dictionary for use in the processor tests.

    All seven ``TransitionKey`` entries are always present in the result;
    any component the caller leaves out is recorded as ``None``.
    """
    transition = {}
    transition[TransitionKey.OBSERVATION] = observation
    transition[TransitionKey.ACTION] = action
    transition[TransitionKey.REWARD] = reward
    transition[TransitionKey.DONE] = done
    transition[TransitionKey.TRUNCATED] = truncated
    transition[TransitionKey.INFO] = info
    transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
    return transition
def test_process_single_image():
@@ -33,10 +49,10 @@ def test_process_single_image():
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that the image was processed correctly
assert "observation.image" in processed_obs
@@ -60,10 +76,10 @@ def test_process_image_dict():
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
observation = {"pixels": {"camera1": image1, "camera2": image2}}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that both images were processed
assert "observation.images.camera1" in processed_obs
@@ -82,10 +98,10 @@ def test_process_batched_image():
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that batch dimension is preserved
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
@@ -98,7 +114,7 @@ def test_invalid_image_format():
# Test wrong channel order (channels first)
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
with pytest.raises(ValueError, match="Expected channel-last images"):
processor(transition)
@@ -111,7 +127,7 @@ def test_invalid_image_dtype():
# Test wrong dtype
image = np.random.rand(64, 64, 3).astype(np.float32)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
with pytest.raises(ValueError, match="Expected torch.uint8 images"):
processor(transition)
@@ -122,10 +138,10 @@ def test_no_pixels_in_observation():
processor = ImageProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Should preserve other data unchanged
assert "other_data" in processed_obs
@@ -136,7 +152,7 @@ def test_none_observation():
"""Test processor with None observation."""
processor = ImageProcessor()
transition = (None, None, None, None, None, None, None)
transition = create_transition()
result = processor(transition)
assert result == transition
@@ -167,10 +183,10 @@ def test_process_environment_state():
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
observation = {"environment_state": env_state}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that environment_state was renamed and processed
assert "observation.environment_state" in processed_obs
@@ -188,10 +204,10 @@ def test_process_agent_pos():
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that agent_pos was renamed and processed
assert "observation.state" in processed_obs
@@ -211,10 +227,10 @@ def test_process_batched_states():
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
observation = {"environment_state": env_state, "agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that batch dimensions are preserved
assert processed_obs["observation.environment_state"].shape == (2, 2)
@@ -229,10 +245,10 @@ def test_process_both_states():
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that both states were processed
assert "observation.environment_state" in processed_obs
@@ -251,10 +267,10 @@ def test_no_states_in_observation():
processor = StateProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Should preserve data unchanged
np.testing.assert_array_equal(processed_obs, observation)
@@ -275,10 +291,10 @@ def test_complete_observation_processing():
"agent_pos": agent_pos,
"other_data": "preserve_me",
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that image was processed
assert "observation.image" in processed_obs
@@ -303,10 +319,10 @@ def test_image_only_processing():
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.image" in processed_obs
assert len(processed_obs) == 1
@@ -318,10 +334,10 @@ def test_state_only_processing():
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
@@ -332,10 +348,10 @@ def test_empty_observation():
processor = VanillaObservationProcessor()
observation = {}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
assert processed_obs == {}
@@ -369,8 +385,8 @@ def test_equivalent_to_original_function():
original_result = preprocess_observation(observation)
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
transition = create_transition(observation=observation)
processor_result = processor(transition)[TransitionKey.OBSERVATION]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
@@ -396,8 +412,8 @@ def test_equivalent_with_image_dict():
original_result = preprocess_observation(observation)
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
transition = create_transition(observation=observation)
processor_result = processor(transition)[TransitionKey.OBSERVATION]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
+107 -73
View File
@@ -18,7 +18,7 @@ import json
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict
from typing import Any
import numpy as np
import pytest
@@ -26,6 +26,22 @@ import torch
import torch.nn as nn
from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
from lerobot.processor.pipeline import TransitionKey
def create_transition(
    observation=None,
    action=None,
    reward=0.0,
    done=False,
    truncated=False,
    info=None,
    complementary_data=None,
):
    """Build an EnvTransition dictionary with pipeline-friendly defaults.

    ``reward``, ``done`` and ``truncated`` default to ``0.0``, ``False`` and
    ``False``. ``info`` and ``complementary_data`` are replaced by a new empty
    dict when passed as ``None``; all other arguments are stored as given.
    """
    # Create the empty dicts per call rather than in the signature so that
    # separate transitions never share one mutable object.
    if info is None:
        info = {}
    if complementary_data is None:
        complementary_data = {}

    transition = {}
    transition[TransitionKey.OBSERVATION] = observation
    transition[TransitionKey.ACTION] = action
    transition[TransitionKey.REWARD] = reward
    transition[TransitionKey.DONE] = done
    transition[TransitionKey.TRUNCATED] = truncated
    transition[TransitionKey.INFO] = info
    transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
    return transition
@dataclass
@@ -45,14 +61,16 @@ class MockStep:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Add a counter to the complementary_data."""
obs, action, reward, done, truncated, info, comp_data = transition
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp_data = {} if comp_data is None else dict(comp_data) # Make a copy
comp_data[f"{self.name}_counter"] = self.counter
self.counter += 1
return (obs, action, reward, done, truncated, info, comp_data)
# Create a new transition with updated complementary_data
new_transition = transition.copy()
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
# Return all JSON-serializable attributes that should be persisted
@@ -79,12 +97,14 @@ class MockStepWithoutOptionalMethods:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Multiply reward by multiplier."""
obs, action, reward, done, truncated, info, comp_data = transition
reward = transition.get(TransitionKey.REWARD)
if reward is not None:
reward = reward * self.multiplier
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = reward * self.multiplier
return new_transition
return (obs, action, reward, done, truncated, info, comp_data)
return transition
@dataclass
@@ -105,7 +125,7 @@ class MockStepWithTensorState:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Update running statistics."""
obs, action, reward, done, truncated, info, comp_data = transition
reward = transition.get(TransitionKey.REWARD)
if reward is not None:
# Update running mean
@@ -143,7 +163,7 @@ def test_empty_pipeline():
"""Test pipeline with no steps."""
pipeline = RobotProcessor()
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
result = pipeline(transition)
assert result == transition
@@ -155,15 +175,15 @@ def test_single_step_pipeline():
step = MockStep("test_step")
pipeline = RobotProcessor([step])
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
result = pipeline(transition)
assert len(pipeline) == 1
assert result[6]["test_step_counter"] == 0 # complementary_data
assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0
# Call again to test counter increment
result = pipeline(transition)
assert result[6]["test_step_counter"] == 1
assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 1
def test_multiple_steps_pipeline():
@@ -172,46 +192,46 @@ def test_multiple_steps_pipeline():
step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
result = pipeline(transition)
assert len(pipeline) == 2
assert result[6]["step1_counter"] == 0
assert result[6]["step2_counter"] == 0
assert result[TransitionKey.COMPLEMENTARY_DATA]["step1_counter"] == 0
assert result[TransitionKey.COMPLEMENTARY_DATA]["step2_counter"] == 0
def test_invalid_transition_format():
"""Test pipeline with invalid transition format."""
pipeline = RobotProcessor([MockStep()])
# Test with wrong number of elements
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
pipeline((None, None, 0.0)) # Only 3 elements
# Test with wrong type (tuple instead of dict)
with pytest.raises(ValueError, match="EnvTransition must be a dictionary"):
pipeline((None, None, 0.0, False, False, {}, {})) # Tuple instead of dict
# Test with wrong type
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
pipeline("not a tuple")
# Test with wrong type (string)
with pytest.raises(ValueError, match="EnvTransition must be a dictionary"):
pipeline("not a dict")
def test_step_through():
"""Test step_through method with tuple input."""
"""Test step_through method with dict input."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
results = list(pipeline.step_through(transition))
assert len(results) == 3 # Original + 2 steps
assert results[0] == transition # Original
assert "step1_counter" in results[1][6] # After step1
assert "step2_counter" in results[2][6] # After step2
assert "step1_counter" in results[1][TransitionKey.COMPLEMENTARY_DATA] # After step1
assert "step2_counter" in results[2][TransitionKey.COMPLEMENTARY_DATA] # After step2
# Ensure all results are tuples (same format as input)
# Ensure all results are dicts (same format as input)
for result in results:
assert isinstance(result, tuple)
assert len(result) == 7
assert isinstance(result, dict)
assert all(isinstance(k, TransitionKey) for k in result.keys())
def test_step_through_with_dict():
@@ -279,7 +299,7 @@ def test_hooks():
pipeline.register_before_step_hook(before_hook)
pipeline.register_after_step_hook(after_hook)
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
pipeline(transition)
assert before_calls == [0]
@@ -292,15 +312,16 @@ def test_hook_modification():
pipeline = RobotProcessor([step])
def modify_reward_hook(idx: int, transition: EnvTransition):
    """Return a shallow copy of ``transition`` with its reward forced to 42.0.

    The incoming dict is left untouched so the caller's transition is not
    mutated by the hook. ``idx`` is accepted to match the hook signature and
    is not used.
    """
    new_transition = transition.copy()
    new_transition[TransitionKey.REWARD] = 42.0
    return new_transition
pipeline.register_before_step_hook(modify_reward_hook)
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
result = pipeline(transition)
assert result[2] == 42.0 # reward modified by hook
assert result[TransitionKey.REWARD] == 42.0 # reward modified by hook
def test_reset():
@@ -316,7 +337,7 @@ def test_reset():
pipeline.register_reset_hook(reset_hook)
# Make some calls to increment counter
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
pipeline(transition)
pipeline(transition)
@@ -335,7 +356,7 @@ def test_profile_steps():
step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
profile_results = pipeline.profile_steps(transition, num_runs=10)
@@ -397,10 +418,10 @@ def test_step_without_optional_methods():
step = MockStepWithoutOptionalMethods(multiplier=3.0)
pipeline = RobotProcessor([step])
transition = (None, None, 2.0, False, False, {}, {})
transition = create_transition(reward=2.0)
result = pipeline(transition)
assert result[2] == 6.0 # 2.0 * 3.0
assert result[TransitionKey.REWARD] == 6.0 # 2.0 * 3.0
# Reset should work even if step doesn't implement reset
pipeline.reset()
@@ -419,7 +440,7 @@ def test_mixed_json_and_tensor_state():
# Process some transitions with rewards
for i in range(10):
transition = (None, None, float(i), False, False, {}, {})
transition = create_transition(reward=float(i))
pipeline(transition)
# Check state
@@ -466,7 +487,7 @@ class MockModuleStep(nn.Module):
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Process transition and update running mean."""
obs, action, reward, done, truncated, info, comp_data = transition
obs = transition.get(TransitionKey.OBSERVATION)
if obs is not None and isinstance(obs, torch.Tensor):
# Process observation through linear layer
@@ -509,7 +530,7 @@ def test_to_device_with_state_dict():
# Process some transitions to populate state
for i in range(10):
transition = (None, None, float(i), False, False, {}, {})
transition = create_transition(reward=float(i))
pipeline(transition)
# Check initial device (should be CPU)
@@ -551,7 +572,7 @@ def test_to_device_with_module():
# Process some data
obs = torch.randn(2, 5)
transition = (obs, None, 1.0, False, False, {}, {})
transition = create_transition(observation=obs, reward=1.0)
pipeline(transition)
# Check initial device
@@ -575,7 +596,7 @@ def test_to_device_with_module():
# Verify the module still works after transfer
obs_cuda = torch.randn(2, 5, device="cuda:0")
transition = (obs_cuda, None, 1.0, False, False, {}, {})
transition = create_transition(observation=obs_cuda, reward=1.0)
pipeline(transition) # Should not raise an error
@@ -589,7 +610,7 @@ def test_to_device_mixed_steps():
# Process some data
for i in range(5):
transition = (torch.randn(2, 10), None, float(i), False, False, {}, {})
transition = create_transition(observation=torch.randn(2, 10), reward=float(i))
pipeline(transition)
# Check initial state
@@ -630,7 +651,7 @@ def test_to_device_preserves_functionality():
# Process initial data
rewards = [1.0, 2.0, 3.0]
for r in rewards:
transition = (None, None, r, False, False, {}, {})
transition = create_transition(reward=r)
pipeline(transition)
# Check state before transfer
@@ -645,7 +666,7 @@ def test_to_device_preserves_functionality():
assert step.running_count == initial_count
# Process more data to ensure functionality
transition = (None, None, 4.0, False, False, {}, {})
transition = create_transition(reward=4.0)
_ = pipeline(transition)
assert step.running_count == 4
@@ -700,7 +721,8 @@ class MockNonModuleStepWithState:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Process transition using tensor operations."""
obs, action, reward, done, truncated, info, comp_data = transition
obs = transition.get(TransitionKey.OBSERVATION)
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if obs is not None and isinstance(obs, torch.Tensor) and obs.numel() >= self.feature_dim:
# Perform some tensor operations
@@ -718,7 +740,12 @@ class MockNonModuleStepWithState:
comp_data[f"{self.name}_mean_output"] = output.mean().item()
comp_data[f"{self.name}_steps"] = self.step_count.item()
return (obs, action, reward, done, truncated, info, comp_data)
# Return updated transition
new_transition = transition.copy()
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
return transition
def get_config(self) -> dict[str, Any]:
return {
@@ -763,9 +790,9 @@ def test_to_device_non_module_class():
# Process some data to populate state
for i in range(3):
obs = torch.randn(2, 5)
transition = (obs, None, float(i), False, False, {}, {})
transition = create_transition(observation=obs, reward=float(i))
result = pipeline(transition)
comp_data = result[6]
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert f"{non_module_step.name}_steps" in comp_data
# Verify all tensors are on CPU initially
@@ -811,9 +838,9 @@ def test_to_device_non_module_class():
# Test that step still works on GPU
obs_gpu = torch.randn(2, 5, device="cuda")
transition = (obs_gpu, None, 1.0, False, False, {}, {})
transition = create_transition(observation=obs_gpu, reward=1.0)
result = pipeline(transition)
comp_data = result[6]
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
# Verify processing worked
assert comp_data[f"{non_module_step.name}_steps"] == 4
@@ -835,7 +862,7 @@ def test_to_device_module_vs_non_module():
# Process some data
obs = torch.randn(2, 5)
transition = (obs, None, 1.0, False, False, {}, {})
transition = create_transition(observation=obs, reward=1.0)
_ = pipeline(transition)
# Check initial devices
@@ -860,7 +887,7 @@ def test_to_device_module_vs_non_module():
# Process data on GPU
obs_gpu = torch.randn(2, 5, device="cuda")
transition = (obs_gpu, None, 2.0, False, False, {}, {})
transition = create_transition(observation=obs_gpu, reward=2.0)
_ = pipeline(transition)
# Verify both steps processed the data
@@ -889,7 +916,8 @@ class MockStepWithNonSerializableParam:
self.env = env # Non-serializable parameter (like gym.Env)
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs, action, reward, done, truncated, info, comp_data = transition
reward = transition.get(TransitionKey.REWARD)
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
# Use the env parameter if provided
if self.env is not None:
@@ -897,10 +925,14 @@ class MockStepWithNonSerializableParam:
comp_data[f"{self.name}_env_info"] = str(self.env)
# Apply multiplier to reward
new_transition = transition.copy()
if reward is not None:
reward = reward * self.multiplier
new_transition[TransitionKey.REWARD] = reward * self.multiplier
return (obs, action, reward, done, truncated, info, comp_data)
if comp_data:
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
# Note: env is intentionally NOT included here as it's not serializable
@@ -928,13 +960,15 @@ class RegisteredMockStep:
device: str = "cpu"
def __call__(self, transition: EnvTransition) -> EnvTransition:
    """Record this step's ``value`` and ``device`` in the complementary data.

    Works on copies: the complementary-data mapping is duplicated before it is
    written to, and the transition dict is shallow-copied before the updated
    mapping is attached, so the caller's objects are not mutated.
    """
    comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
    # The key may be present with a None value; normalise to a fresh dict.
    comp_data = {} if comp_data is None else dict(comp_data)

    comp_data["registered_step_value"] = self.value
    comp_data["registered_step_device"] = self.device

    new_transition = transition.copy()
    new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
    return new_transition
def get_config(self) -> dict[str, Any]:
return {
@@ -993,18 +1027,18 @@ def test_from_pretrained_with_overrides():
assert loaded_pipeline.name == "TestOverrides"
# Test the loaded steps
transition = (None, None, 1.0, False, False, {}, {})
transition = create_transition(reward=1.0)
result = loaded_pipeline(transition)
# Check that overrides were applied
comp_data = result[6]
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert "env_step_env_info" in comp_data
assert comp_data["env_step_env_info"] == "MockEnvironment(test_env)"
assert comp_data["registered_step_value"] == 200
assert comp_data["registered_step_device"] == "cuda"
# Check that multiplier override was applied
assert result[2] == 3.0 # 1.0 * 3.0 (overridden multiplier)
assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0 (overridden multiplier)
def test_from_pretrained_with_partial_overrides():
@@ -1024,13 +1058,13 @@ def test_from_pretrained_with_partial_overrides():
# Both steps will get the override
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
transition = (None, None, 1.0, False, False, {}, {})
transition = create_transition(reward=1.0)
result = loaded_pipeline(transition)
# The reward should be affected by both steps, both getting the override
# First step: 1.0 * 5.0 = 5.0 (overridden)
# Second step: 5.0 * 5.0 = 25.0 (also overridden)
assert result[2] == 25.0
assert result[TransitionKey.REWARD] == 25.0
def test_from_pretrained_invalid_override_key():
@@ -1082,10 +1116,10 @@ def test_from_pretrained_registered_step_override():
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
# Test that overrides were applied
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
result = loaded_pipeline(transition)
comp_data = result[6]
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert comp_data["registered_step_value"] == 999
assert comp_data["registered_step_device"] == "cuda"
@@ -1110,13 +1144,13 @@ def test_from_pretrained_mixed_registered_and_unregistered():
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
# Test both steps
transition = (None, None, 2.0, False, False, {}, {})
transition = create_transition(reward=2.0)
result = loaded_pipeline(transition)
comp_data = result[6]
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert comp_data["unregistered_env_info"] == "MockEnvironment(mixed_test)"
assert comp_data["registered_step_value"] == 777
assert result[2] == 8.0 # 2.0 * 4.0
assert result[TransitionKey.REWARD] == 8.0 # 2.0 * 4.0
def test_from_pretrained_no_overrides():
@@ -1133,10 +1167,10 @@ def test_from_pretrained_no_overrides():
assert len(loaded_pipeline) == 1
# Test that the step works (env will be None)
transition = (None, None, 1.0, False, False, {}, {})
transition = create_transition(reward=1.0)
result = loaded_pipeline(transition)
assert result[2] == 3.0 # 1.0 * 3.0
assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0
def test_from_pretrained_empty_overrides():
@@ -1153,10 +1187,10 @@ def test_from_pretrained_empty_overrides():
assert len(loaded_pipeline) == 1
# Test that the step works normally
transition = (None, None, 1.0, False, False, {}, {})
transition = create_transition(reward=1.0)
result = loaded_pipeline(transition)
assert result[2] == 2.0
assert result[TransitionKey.REWARD] == 2.0
def test_from_pretrained_override_instantiation_error():
@@ -1185,7 +1219,7 @@ def test_from_pretrained_with_state_and_overrides():
# Process some data to create state
for i in range(10):
transition = (None, None, float(i), False, False, {}, {})
transition = create_transition(reward=float(i))
pipeline(transition)
with tempfile.TemporaryDirectory() as tmp_dir:
+41 -25
View File
@@ -13,14 +13,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
from pathlib import Path
import numpy as np
import torch
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionIndex
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
def create_transition(
    observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
    """Assemble an EnvTransition dictionary for use in tests.

    Every field defaults to ``None``. The returned dict always contains all
    seven ``TransitionKey`` entries, inserted in the same order as the
    parameters.
    """
    fields = (
        (TransitionKey.OBSERVATION, observation),
        (TransitionKey.ACTION, action),
        (TransitionKey.REWARD, reward),
        (TransitionKey.DONE, done),
        (TransitionKey.TRUNCATED, truncated),
        (TransitionKey.INFO, info),
        (TransitionKey.COMPLEMENTARY_DATA, complementary_data),
    )
    return dict(fields)
def test_basic_renaming():
@@ -36,10 +50,10 @@ def test_basic_renaming():
"old_key2": np.array([3.0, 4.0]),
"unchanged_key": "keep_me",
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check renamed keys
assert "new_key1" in processed_obs
@@ -63,10 +77,10 @@ def test_empty_rename_map():
"key1": torch.tensor([1.0]),
"key2": "value2",
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# All keys should be unchanged
assert processed_obs.keys() == observation.keys()
@@ -78,7 +92,7 @@ def test_none_observation():
"""Test processor with None observation."""
processor = RenameProcessor(rename_map={"old": "new"})
transition = (None, None, None, None, None, None, None)
transition = create_transition()
result = processor(transition)
# Should return transition unchanged
@@ -98,10 +112,10 @@ def test_overlapping_rename():
"b": 2,
"x": 3,
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that renaming happens correctly
assert "a" not in processed_obs
@@ -124,10 +138,10 @@ def test_partial_rename():
"reward": 1.0,
"info": {"episode": 1},
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check renamed keys
assert "observation.proprio_state" in processed_obs
@@ -178,10 +192,12 @@ def test_integration_with_robot_processor():
"pixels": np.zeros((32, 32, 3), dtype=np.uint8),
"other_data": "preserve_me",
}
transition = (observation, None, 0.5, False, False, {}, {})
transition = create_transition(
observation=observation, reward=0.5, done=False, truncated=False, info={}, complementary_data={}
)
result = pipeline(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check renaming worked through pipeline
assert "observation.state" in processed_obs
@@ -191,8 +207,8 @@ def test_integration_with_robot_processor():
assert processed_obs["other_data"] == "preserve_me"
# Check other transition elements unchanged
assert result[TransitionIndex.REWARD] == 0.5
assert result[TransitionIndex.DONE] is False
assert result[TransitionKey.REWARD] == 0.5
assert result[TransitionKey.DONE] is False
def test_save_and_load_pretrained():
@@ -229,10 +245,10 @@ def test_save_and_load_pretrained():
# Test functionality after loading
observation = {"old_state": [1, 2, 3], "old_image": "image_data"}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = loaded_pipeline(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.state" in processed_obs
assert "observation.image" in processed_obs
@@ -306,17 +322,17 @@ def test_chained_rename_processors():
"img": "image_data",
"extra": "keep_me",
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
# Step through to see intermediate results
results = list(pipeline.step_through(transition))
# After first processor
assert "agent_position" in results[1][TransitionIndex.OBSERVATION]
assert "camera_image" in results[1][TransitionIndex.OBSERVATION]
assert "agent_position" in results[1][TransitionKey.OBSERVATION]
assert "camera_image" in results[1][TransitionKey.OBSERVATION]
# After second processor
final_obs = results[2][TransitionIndex.OBSERVATION]
final_obs = results[2][TransitionKey.OBSERVATION]
assert "observation.state" in final_obs
assert "observation.image" in final_obs
assert final_obs["extra"] == "keep_me"
@@ -343,10 +359,10 @@ def test_nested_observation_rename():
"observation.proprio": torch.randn(7),
"observation.gripper": torch.tensor([0.0]), # Not renamed
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check renames
assert "observation.camera.left_view" in processed_obs
@@ -378,10 +394,10 @@ def test_value_types_preserved():
"old_dict": {"nested": "value"},
"old_list": [1, 2, 3],
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that values and types are preserved
assert torch.equal(processed_obs["new_tensor"], tensor_value)