Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-16 00:59:46 +00:00)
refactor(pipeline): Transition from tuple to dictionary format for EnvTransition
- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
+257 −175
@@ -112,24 +112,23 @@ RobotProcessor solves these issues by providing a declarative pipeline approach

RobotProcessor works with two data formats:

### 1. EnvTransition Dictionary Format

An `EnvTransition` is a dictionary that represents a complete transition in the environment:

```python
from lerobot.processor.pipeline import TransitionKey

# EnvTransition structure:
transition = {
    TransitionKey.OBSERVATION: {"observation.image": ..., "observation.state": ...},  # observation at time t
    TransitionKey.ACTION: [0.1, -0.2, 0.3],  # action taken at time t
    TransitionKey.REWARD: 1.0,  # reward received
    TransitionKey.DONE: False,  # episode done flag
    TransitionKey.TRUNCATED: False,  # episode truncated flag
    TransitionKey.INFO: {"success": True},  # additional info from environment
    TransitionKey.COMPLEMENTARY_DATA: {"step_idx": 42}  # complementary_data for inter-step communication
}
```
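The `TransitionKey` enum itself is not shown in this excerpt. As a rough sketch, a stand-in consistent with the usage above might look like the following — the real definition lives in `lerobot.processor.pipeline`, and its member values may differ:

```python
from enum import Enum


class TransitionKey(str, Enum):
    # Hypothetical stand-in for lerobot.processor.pipeline.TransitionKey;
    # the member values here are assumptions for illustration only.
    OBSERVATION = "observation"
    ACTION = "action"
    REWARD = "reward"
    DONE = "done"
    TRUNCATED = "truncated"
    INFO = "info"
    COMPLEMENTARY_DATA = "complementary_data"


# Keys are enum members, so access is explicit and typo-safe
transition = {
    TransitionKey.OBSERVATION: {"observation.state": [0.1, 0.2]},
    TransitionKey.ACTION: None,
    TransitionKey.REWARD: 0.0,
    TransitionKey.DONE: False,
    TransitionKey.TRUNCATED: False,
    TransitionKey.INFO: {},
    TransitionKey.COMPLEMENTARY_DATA: {},
}
```

A `str`-mixin enum keeps the keys hashable and human-readable when the dictionary is printed or serialized.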

### 2. Batch Dictionary Format
@@ -160,9 +159,17 @@ from lerobot.processor.observation_processor import ImageProcessor

```python
processor = RobotProcessor([ImageProcessor()])

# Works with EnvTransition dictionaries
transition = {
    TransitionKey.OBSERVATION: {"pixels": image_array},
    TransitionKey.ACTION: None,
    TransitionKey.REWARD: 0.0,
    TransitionKey.DONE: False,
    TransitionKey.TRUNCATED: False,
    TransitionKey.INFO: {},
    TransitionKey.COMPLEMENTARY_DATA: {}
}
processed_transition = processor(transition)  # Returns EnvTransition dictionary

# Also works with batch dictionaries
batch = {
```
@@ -176,25 +183,25 @@ batch = {

```python
processed_batch = processor(batch)  # Returns batch dictionary
```

### Using TransitionKey

Use the `TransitionKey` enum to access dictionary elements:

```python
from lerobot.processor.pipeline import TransitionKey

# Good - using TransitionKey
obs = transition[TransitionKey.OBSERVATION]
action = transition[TransitionKey.ACTION]
reward = transition[TransitionKey.REWARD]
done = transition[TransitionKey.DONE]
truncated = transition[TransitionKey.TRUNCATED]
info = transition[TransitionKey.INFO]
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]

# Alternative - using .get() for optional access
obs = transition.get(TransitionKey.OBSERVATION)
action = transition.get(TransitionKey.ACTION)
```
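Because every transition carries all seven keys, the examples in this guide repeat the same seven-line dictionary literal. A small convenience factory can cut that boilerplate; the helper below is hypothetical (not part of lerobot), and `TransitionKey` is stubbed so the sketch runs standalone:

```python
from enum import Enum

try:
    from lerobot.processor.pipeline import TransitionKey
except ImportError:
    # Minimal stand-in so this sketch runs without lerobot installed
    class TransitionKey(str, Enum):
        OBSERVATION = "observation"
        ACTION = "action"
        REWARD = "reward"
        DONE = "done"
        TRUNCATED = "truncated"
        INFO = "info"
        COMPLEMENTARY_DATA = "complementary_data"


def make_transition(observation=None, action=None, reward=0.0, done=False,
                    truncated=False, info=None, complementary_data=None):
    """Hypothetical helper: build a full EnvTransition dict with sensible defaults."""
    return {
        TransitionKey.OBSERVATION: observation,
        TransitionKey.ACTION: action,
        TransitionKey.REWARD: reward,
        TransitionKey.DONE: done,
        TransitionKey.TRUNCATED: truncated,
        TransitionKey.INFO: info if info is not None else {},
        TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
    }


transition = make_transition(observation={"observation.state": [0.0]}, info={"step": 3})
```

Call sites then only name the fields they care about, while downstream steps can still rely on all seven keys being present.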

### Default Conversion Functions
@@ -203,43 +210,49 @@ RobotProcessor uses these default conversion functions:

```python
def _default_batch_to_transition(batch):
    """Default conversion from batch dict to EnvTransition dictionary."""
    # Extract observation keys (anything starting with "observation.")
    observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}

    observation = None
    if observation_keys:
        observation = {}
        # Keep observation.* keys as-is (don't remove "observation." prefix)
        for key, value in observation_keys.items():
            observation[key] = value

    # Extract padding and task keys for complementary data
    pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
    task_key = {"task": batch["task"]} if "task" in batch else {}
    complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}

    transition = {
        TransitionKey.OBSERVATION: observation,
        TransitionKey.ACTION: batch.get("action"),
        TransitionKey.REWARD: batch.get("next.reward", 0.0),  # Note: "next.reward" not "reward"
        TransitionKey.DONE: batch.get("next.done", False),  # Note: "next.done" not "done"
        TransitionKey.TRUNCATED: batch.get("next.truncated", False),
        TransitionKey.INFO: batch.get("info", {}),
        TransitionKey.COMPLEMENTARY_DATA: complementary_data,
    }
    return transition


def _default_transition_to_batch(transition):
    """Default conversion from EnvTransition dictionary to batch dict."""
    batch = {
        "action": transition.get(TransitionKey.ACTION),
        "next.reward": transition.get(TransitionKey.REWARD, 0.0),
        "next.done": transition.get(TransitionKey.DONE, False),
        "next.truncated": transition.get(TransitionKey.TRUNCATED, False),
        "info": transition.get(TransitionKey.INFO, {}),
    }

    # Add padding and task data from complementary_data
    complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
    if complementary_data:
        pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
        batch.update(pad_data)
        if "task" in complementary_data:
            batch["task"] = complementary_data["task"]

    # Handle observation - flatten dict to observation.* keys if it's a dict
    observation = transition.get(TransitionKey.OBSERVATION)
    if isinstance(observation, dict):
        batch.update(observation)

    return batch
```
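To make the key-naming subtlety concrete ("next.reward"/"next.done" on the batch side, plain REWARD/DONE keys on the transition side), here is a runnable round-trip sketch that mirrors the logic above with a stubbed `TransitionKey`; the real private helpers in lerobot may differ in detail:

```python
from enum import Enum


class TransitionKey(str, Enum):
    # Stand-in for lerobot's enum, for illustration only
    OBSERVATION = "observation"
    ACTION = "action"
    REWARD = "reward"
    DONE = "done"
    TRUNCATED = "truncated"
    INFO = "info"
    COMPLEMENTARY_DATA = "complementary_data"


def batch_to_transition(batch):
    # Mirrors _default_batch_to_transition above (simplified)
    observation = {k: v for k, v in batch.items() if k.startswith("observation.")} or None
    return {
        TransitionKey.OBSERVATION: observation,
        TransitionKey.ACTION: batch.get("action"),
        TransitionKey.REWARD: batch.get("next.reward", 0.0),
        TransitionKey.DONE: batch.get("next.done", False),
        TransitionKey.TRUNCATED: batch.get("next.truncated", False),
        TransitionKey.INFO: batch.get("info", {}),
        TransitionKey.COMPLEMENTARY_DATA: {k: v for k, v in batch.items() if "_is_pad" in k},
    }


def transition_to_batch(transition):
    # Mirrors _default_transition_to_batch above (simplified)
    batch = {
        "action": transition.get(TransitionKey.ACTION),
        "next.reward": transition.get(TransitionKey.REWARD, 0.0),
        "next.done": transition.get(TransitionKey.DONE, False),
        "next.truncated": transition.get(TransitionKey.TRUNCATED, False),
        "info": transition.get(TransitionKey.INFO, {}),
    }
    batch.update(transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
    obs = transition.get(TransitionKey.OBSERVATION)
    if isinstance(obs, dict):
        batch.update(obs)
    return batch


batch = {"observation.state": [1.0], "action": [0.5], "next.reward": 2.0, "action_is_pad": False}
round_tripped = transition_to_batch(batch_to_transition(batch))
```

Converting batch → transition → batch preserves the observation, action, reward, and padding keys, because the two functions agree on the "next.*" naming convention.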
@@ -250,33 +263,32 @@ You can customize how RobotProcessor converts between formats:

```python
def custom_batch_to_transition(batch):
    """Custom conversion from batch dict to EnvTransition dictionary."""
    # Extract observation keys (anything starting with "observation.")
    observation = {k: v for k, v in batch.items() if k.startswith("observation.")}

    return {
        TransitionKey.OBSERVATION: observation,
        TransitionKey.ACTION: batch.get("action"),
        TransitionKey.REWARD: batch.get("reward", 0.0),  # Use "reward" instead of "next.reward"
        TransitionKey.DONE: batch.get("done", False),  # Use "done" instead of "next.done"
        TransitionKey.TRUNCATED: batch.get("truncated", False),
        TransitionKey.INFO: batch.get("info", {}),
        TransitionKey.COMPLEMENTARY_DATA: batch.get("complementary_data", {})
    }


def custom_transition_to_batch(transition):
    """Custom conversion from EnvTransition dictionary to batch dict."""
    batch = {
        "action": transition.get(TransitionKey.ACTION),
        "reward": transition.get(TransitionKey.REWARD),  # Use "reward" instead of "next.reward"
        "done": transition.get(TransitionKey.DONE),  # Use "done" instead of "next.done"
        "truncated": transition.get(TransitionKey.TRUNCATED),
        "info": transition.get(TransitionKey.INFO),
    }

    # Flatten observation dict
    obs = transition.get(TransitionKey.OBSERVATION)
    if obs:
        batch.update(obs)

    return batch
```
@@ -292,21 +304,21 @@ processor = RobotProcessor(

### Advanced: Controlling Output Format with `to_output`

The `to_output` function determines what format is returned when you call the processor with a batch dictionary. Sometimes you want to output `EnvTransition` dictionaries even when you input batch dictionaries:

```python
# Identity function to always return EnvTransition dictionaries
def keep_as_transition(transition):
    """Always return EnvTransition dictionary regardless of input format."""
    return transition

# Processor that always outputs EnvTransition dictionaries
processor = RobotProcessor(
    steps=[ImageProcessor(), StateProcessor()],
    to_output=keep_as_transition  # Always return dictionary format
)

# Even when called with batch dict, returns EnvTransition dictionary
batch = {
    "observation.image": image_tensor,
    "action": action_tensor,
```
@@ -316,13 +328,13 @@ batch = {

```python
    "info": info_dict
}

result = processor(batch)  # Returns EnvTransition dictionary, not batch dict!
print(type(result))  # <class 'dict'>
```

### Real-World Example: Environment Interaction

This is particularly useful for environment interaction where you want consistent dictionary output:

```python
from lerobot.processor.observation_processor import VanillaObservationProcessor
```
@@ -332,7 +344,7 @@ from lerobot.processor.observation_processor import VanillaObservationProcessor

```python
env_processor = RobotProcessor(
    [VanillaObservationProcessor()],
    to_transition=lambda x: x,  # Pass through - no conversion needed
    to_output=lambda x: x,  # Always return EnvTransition dictionary
)

# Environment interaction loop
```
@@ -340,12 +352,20 @@ env = make_env()

```python
obs, info = env.reset()

for step in range(1000):
    # Create transition - input is already in dictionary format
    transition = {
        TransitionKey.OBSERVATION: obs,
        TransitionKey.ACTION: None,
        TransitionKey.REWARD: 0.0,
        TransitionKey.DONE: False,
        TransitionKey.TRUNCATED: False,
        TransitionKey.INFO: info,
        TransitionKey.COMPLEMENTARY_DATA: {"step": step}
    }

    # Process - output is guaranteed to be EnvTransition dictionary
    processed_transition = env_processor(transition)
    processed_obs = processed_transition[TransitionKey.OBSERVATION]

    # Use with policy
    action = policy.select_action(processed_obs)
```
@@ -357,12 +377,12 @@ for step in range(1000):

### When to Use Different Output Formats

**Use EnvTransition dictionary output when:**

- Environment interaction and real-time control
- You need to access individual transition components frequently
- Working with gym environments that expect structured data
- You need the flexibility of dictionary operations

**Use batch dictionary output when:**
@@ -372,10 +392,10 @@ for step in range(1000):

- You need the standardized "next.\*" key format

```python
# For environment interaction - use dictionary output
env_processor = RobotProcessor(
    steps=[ImageProcessor(), StateProcessor()],
    to_output=lambda x: x  # Return EnvTransition dictionary
)

# For training - use batch output (default)
```
@@ -391,9 +411,17 @@ for batch in dataloader:

```python
# Environment loop
for step in range(1000):
    transition = {
        TransitionKey.OBSERVATION: obs,
        TransitionKey.ACTION: None,
        TransitionKey.REWARD: 0.0,
        TransitionKey.DONE: False,
        TransitionKey.TRUNCATED: False,
        TransitionKey.INFO: info,
        TransitionKey.COMPLEMENTARY_DATA: {}
    }
    processed_transition = env_processor(transition)  # Returns EnvTransition dictionary
    obs = processed_transition[TransitionKey.OBSERVATION]
    action = policy.select_action(obs)
```
@@ -426,7 +454,7 @@ batch = {

Let's create a processor that properly handles image and state preprocessing:

```python
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
from lerobot.processor.observation_processor import ImageProcessor, StateProcessor
import numpy as np
```
@@ -440,7 +468,15 @@ observation = {

```python
}

# Create a full transition
transition = {
    TransitionKey.OBSERVATION: observation,
    TransitionKey.ACTION: None,
    TransitionKey.REWARD: 0.0,
    TransitionKey.DONE: False,
    TransitionKey.TRUNCATED: False,
    TransitionKey.INFO: {},
    TransitionKey.COMPLEMENTARY_DATA: {}
}

# Create and use the processor
processor = RobotProcessor([
```
@@ -449,7 +485,7 @@ processor = RobotProcessor([

```python
])

processed_transition = processor(transition)
processed_obs = processed_transition[TransitionKey.OBSERVATION]

# Check the results
print("Original keys:", observation.keys())
```
@@ -541,11 +577,19 @@ obs, info = env.reset()

```python
for step in range(1000):
    # Raw environment observation
    transition = {
        TransitionKey.OBSERVATION: obs,
        TransitionKey.ACTION: None,
        TransitionKey.REWARD: 0.0,
        TransitionKey.DONE: False,
        TransitionKey.TRUNCATED: False,
        TransitionKey.INFO: info,
        TransitionKey.COMPLEMENTARY_DATA: {}
    }

    # Process for policy input
    processed_transition = online_processor(transition)
    processed_obs = processed_transition[TransitionKey.OBSERVATION]

    # Get action from policy
    action = policy.select_action(processed_obs)
```
@@ -585,15 +629,16 @@ class ImagePadder:

```python
    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Main processing method - required for all steps."""
        obs = transition.get(TransitionKey.OBSERVATION)

        if obs is None:
            return transition

        # Process all image observations
        processed_obs = dict(obs)  # Create a copy
        for key in list(processed_obs.keys()):
            if key.startswith("observation.images."):
                img = processed_obs[key]
                # Calculate padding
                _, _, h, w = img.shape
                pad_h = max(0, self.target_height - h)
```
@@ -609,10 +654,12 @@ class ImagePadder:

```python
                img = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom),
                            mode='constant', value=self.pad_value)

                processed_obs[key] = img

        # Return modified transition
        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = processed_obs
        return new_transition

    def get_config(self) -> Dict[str, Any]:
        """Return JSON-serializable configuration - required for save/load."""
```
@@ -694,8 +741,8 @@ class ImageStatisticsCalculator:

```python
    """Calculate image statistics and pass to next steps."""

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        obs = transition.get(TransitionKey.OBSERVATION)
        comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}

        if obs is None:
            return transition
```
@@ -714,18 +761,13 @@ class ImageStatisticsCalculator:

```python
            image_stats[key] = stats

        # Store in complementary_data for next steps
        comp_data = dict(comp_data)  # Make a copy
        comp_data["image_statistics"] = image_stats

        # Return transition with updated complementary_data
        new_transition = transition.copy()
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
        return new_transition


@dataclass
class AdaptiveBrightnessAdjuster:
```
@@ -734,8 +776,8 @@ class AdaptiveBrightnessAdjuster:

```python
    target_brightness: float = 0.5

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        obs = transition.get(TransitionKey.OBSERVATION)
        comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}

        if obs is None or "image_statistics" not in comp_data:
            return transition
```
@@ -743,15 +785,18 @@ class AdaptiveBrightnessAdjuster:

```python
        # Use statistics from previous step
        image_stats = comp_data["image_statistics"]

        processed_obs = dict(obs)  # Create a copy
        for key in processed_obs:
            if key.startswith("observation.images.") and key in image_stats:
                current_mean = image_stats[key]["mean"]
                brightness_adjust = self.target_brightness - current_mean

                # Adjust brightness
                processed_obs[key] = torch.clamp(processed_obs[key] + brightness_adjust, 0, 1)

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = processed_obs
        return new_transition


# Use them together
processor = RobotProcessor([
```
@@ -782,7 +827,7 @@ class ActionRepeatStep:

```python
    env: gym.Env = None  # This can't be serialized to JSON!

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition.get(TransitionKey.ACTION)

        if self.env is not None and action is not None:
            # Repeat action multiple times in environment
```
@@ -792,9 +837,13 @@ class ActionRepeatStep:

```python
                total_reward += r
                if d or t:
                    break

            # Update reward in transition
            new_transition = transition.copy()
            new_transition[TransitionKey.REWARD] = total_reward
            return new_transition

        return transition

    def get_config(self) -> dict[str, Any]:
        # Note: env is NOT included because it's not serializable
```
@@ -1211,7 +1260,7 @@ This enables sharing of preprocessing logic while allowing each user to provide

Here's a complete example showing proper device management and all features:

```python
from lerobot.processor.pipeline import RobotProcessor, ProcessorStepRegistry, TransitionKey
import torch
import torch.nn.functional as F
import numpy as np
```
@@ -1224,23 +1273,29 @@ class DeviceMover:

```python
    device: str = "cuda"

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        obs = transition.get(TransitionKey.OBSERVATION)

        if obs is None:
            return transition

        # Move all tensor observations to device
        processed_obs = dict(obs)  # Create a copy
        for key, value in processed_obs.items():
            if isinstance(value, torch.Tensor):
                processed_obs[key] = value.to(self.device)

        # Also handle action if present
        action = transition.get(TransitionKey.ACTION)
        if action is not None and isinstance(action, torch.Tensor):
            action = action.to(self.device)
            new_transition = transition.copy()
            new_transition[TransitionKey.OBSERVATION] = processed_obs
            new_transition[TransitionKey.ACTION] = action
            return new_transition

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = processed_obs
        return new_transition

    def get_config(self) -> Dict[str, Any]:
        return {"device": str(self.device)}
```
@@ -1260,7 +1315,7 @@ class RunningNormalizer:

```python
        self.initialized = False

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        obs = transition.get(TransitionKey.OBSERVATION)

        if obs is None or "observation.state" not in obs:
            return transition
```
@@ -1284,9 +1339,12 @@ class RunningNormalizer:

```python
        # Normalize
        state_normalized = (state - self.running_mean) / (self.running_var + 1e-8).sqrt()
        processed_obs = dict(obs)  # Create a copy
        processed_obs["observation.state"] = state_normalized

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = processed_obs
        return new_transition

    def get_config(self) -> Dict[str, Any]:
        return {
```
@@ -1317,7 +1375,7 @@ class RunningNormalizer:

```python
processor = RobotProcessor([
    ImageProcessor(),  # Convert images to float32 [0,1]
    StateProcessor(),  # Convert states to torch tensors
    ImagePadder(target_height=224, target_width=224),  # Pad images to standard size
    DeviceMover("cuda"),  # Move everything to GPU
    RunningNormalizer(7),  # Normalize states
], name="CompletePreprocessor")
```
@@ -1330,11 +1388,19 @@ obs = {

```python
    "pixels": {"cam": np.random.randint(0, 255, (200, 300, 3), dtype=np.uint8)},
    "agent_pos": np.random.randn(7).astype(np.float32)
}
transition = {
    TransitionKey.OBSERVATION: obs,
    TransitionKey.ACTION: None,
    TransitionKey.REWARD: 0.0,
    TransitionKey.DONE: False,
    TransitionKey.TRUNCATED: False,
    TransitionKey.INFO: {},
    TransitionKey.COMPLEMENTARY_DATA: {}
}

# Everything is processed and on GPU
processed = processor(transition)
print(processed[TransitionKey.OBSERVATION]["observation.images.cam"].device)  # cuda:0
```

## Solving Real-World Problems with RobotProcessor
@@ -1358,22 +1424,25 @@ class KeyRemapper:

```python
    })

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        obs = transition.get(TransitionKey.OBSERVATION)
        if obs is None:
            return transition

        # Create new observation with renamed keys
        processed_obs = dict(obs)  # Create a copy
        renamed_obs = {}
        for old_key, new_key in self.key_mapping.items():
            if old_key in processed_obs:
                renamed_obs[new_key] = processed_obs[old_key]

        # Keep any unmapped keys as-is
        for key, value in processed_obs.items():
            if key not in self.key_mapping:
                renamed_obs[key] = value

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = renamed_obs
        return new_transition
```

### Workspace-Focused Image Processing
@@ -1390,13 +1459,14 @@ class WorkspaceCropper:

```python
    output_size: Tuple[int, int] = (224, 224)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        obs = transition.get(TransitionKey.OBSERVATION)
        if obs is None:
            return transition

        processed_obs = dict(obs)  # Create a copy
        for key in list(processed_obs.keys()):
            if key.startswith("observation.images."):
                img = processed_obs[key]
                # Crop to workspace
                x1, y1, x2, y2 = self.crop_bbox
                img_cropped = img[:, :, y1:y2, x1:x2]
```
@@ -1407,9 +1477,11 @@ class WorkspaceCropper:

```python
                    mode='bilinear',
                    align_corners=False
                )
                processed_obs[key] = img_resized

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = processed_obs
        return new_transition
```

### Building Complete Pipelines for Different Robots
@@ -1471,7 +1543,7 @@ The beauty of this approach is that:

```python
def __call__(self, transition: EnvTransition) -> EnvTransition:
    obs = transition.get(TransitionKey.OBSERVATION)

    # Always check if observation exists
    if obs is None:
```
@@ -1496,7 +1568,7 @@ return (modified_obs, None, 0.0, False, False, {}, {})

```python
def __call__(self, transition: EnvTransition) -> EnvTransition:
    obs = transition.get(TransitionKey.OBSERVATION)

    if self.store_previous:
        # Good - clone to avoid reference issues
```
@@ -1522,7 +1594,7 @@ def state_dict(self) -> Dict[str, torch.Tensor]:

Here's how to use RobotProcessor in a real robot control loop, showing both transition and batch formats:

```python
from lerobot.processor.pipeline import RobotProcessor, ProcessorStepRegistry, TransitionKey
from lerobot.policies.act.modeling_act import ACTPolicy
from pathlib import Path
import time
```
@@ -1545,11 +1617,13 @@ class ActionClipper:

```python
    max_value: float = 1.0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition.get(TransitionKey.ACTION)

        if action is not None:
            action = torch.clamp(action, self.min_value, self.max_value)
            new_transition = transition.copy()
            new_transition[TransitionKey.ACTION] = action
            return new_transition

        return transition
```
```python
for episode in range(10):
    obs, info = env.reset()  # gymnasium-style reset; elided in the original excerpt

    for step in range(1000):
        # Create transition with raw observation
        transition = {
            TransitionKey.OBSERVATION: obs,
            TransitionKey.ACTION: None,
            TransitionKey.REWARD: 0.0,
            TransitionKey.DONE: False,
            TransitionKey.TRUNCATED: False,
            TransitionKey.INFO: info,
            TransitionKey.COMPLEMENTARY_DATA: {"step": step},
        }

        # Preprocess - works with dictionary format
        processed_transition = preprocessor(transition)
        processed_obs = processed_transition.get(TransitionKey.OBSERVATION)

        # Get action from policy
        with torch.no_grad():
            action = policy.select_action(processed_obs)

        # Postprocess action
        action_transition = {
            TransitionKey.OBSERVATION: processed_obs,
            TransitionKey.ACTION: action,
            TransitionKey.REWARD: 0.0,
            TransitionKey.DONE: False,
            TransitionKey.TRUNCATED: False,
            TransitionKey.INFO: info,
            # Store the raw action in complementary_data for later inspection
            TransitionKey.COMPLEMENTARY_DATA: {"raw_action": action.clone()},
        }
        processed_action_transition = postprocessor(action_transition)
        final_action = processed_action_transition.get(TransitionKey.ACTION)

        # Execute action
        obs, reward, terminated, truncated, info = env.step(final_action.cpu().numpy())
```
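Conceptually, the preprocess/postprocess calls above just fold a transition dictionary through a sequence of steps. The sketch below shows that idea with plain Python stand-ins; `TransitionKey`, `clip_action`, and `run_pipeline` are illustrative, not lerobot's actual API:

```python
from enum import Enum
from functools import reduce


class TransitionKey(Enum):
    # Illustrative stand-in for lerobot.processor.pipeline.TransitionKey
    OBSERVATION = "observation"
    ACTION = "action"


def clip_action(transition: dict) -> dict:
    # One "step": clip each action component to [-1, 1], copy-on-write
    action = transition.get(TransitionKey.ACTION)
    if action is None:
        return transition
    new_transition = transition.copy()
    new_transition[TransitionKey.ACTION] = [max(-1.0, min(1.0, a)) for a in action]
    return new_transition


def log_action(transition: dict) -> dict:
    # A pass-through step with a side effect, like a logging hook
    print("action:", transition.get(TransitionKey.ACTION))
    return transition


def run_pipeline(steps, transition: dict) -> dict:
    # Fold the transition through each step, left to right
    return reduce(lambda t, step: step(t), steps, transition)


transition = {TransitionKey.OBSERVATION: {"state": [0.5]}, TransitionKey.ACTION: [2.5, -0.3]}
out = run_pipeline([clip_action, log_action], transition)
print(out[TransitionKey.ACTION])  # [1.0, -0.3]
```

Because each step copies before mutating, the input `transition` still holds the unclipped action afterward.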
Use the full power of `RobotProcessor` for debugging:
```python
# Enable detailed logging
def log_observation_shapes(step_idx: int, transition: EnvTransition):
    obs = transition.get(TransitionKey.OBSERVATION)
    if obs:
        print(f"Step {step_idx} observations:")
        for key, value in obs.items():
            shape = tuple(value.shape) if isinstance(value, torch.Tensor) else type(value)
            print(f"  {key}: {shape}")

processor.register_after_step_hook(log_observation_shapes)

# Monitor complementary data flow
def monitor_complementary_data(step_idx: int, transition: EnvTransition):
    comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
    if comp_data:
        print(f"Step {step_idx} complementary_data: {list(comp_data.keys())}")
    return None

processor.register_before_step_hook(monitor_complementary_data)

# Validate data integrity
def validate_tensors(step_idx: int, transition: EnvTransition):
    obs = transition.get(TransitionKey.OBSERVATION)
    if obs:
        for key, value in obs.items():
            if isinstance(value, torch.Tensor):
                # Completion sketch: flag non-finite values early
                if not torch.isfinite(value).all():
                    print(f"Step {step_idx}: non-finite values in {key}")

processor.register_after_step_hook(validate_tensors)
```
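To see what the hook registration above implies, here is a toy processor that invokes before/after callbacks around each step. The method names mirror the documentation, but the class itself is a minimal sketch, not lerobot's implementation:

```python
from enum import Enum


class TransitionKey(Enum):
    OBSERVATION = "observation"  # stand-in for lerobot's enum


class MiniProcessor:
    def __init__(self, steps):
        self.steps = steps
        self.before_hooks = []
        self.after_hooks = []

    def register_before_step_hook(self, hook):
        self.before_hooks.append(hook)

    def register_after_step_hook(self, hook):
        self.after_hooks.append(hook)

    def __call__(self, transition):
        # Run every step, firing the registered hooks around each one
        for idx, step in enumerate(self.steps):
            for hook in self.before_hooks:
                hook(idx, transition)
            transition = step(transition)
            for hook in self.after_hooks:
                hook(idx, transition)
        return transition


seen = []
processor = MiniProcessor([lambda t: {**t, "scaled": True}])
processor.register_after_step_hook(lambda idx, t: seen.append(idx))
result = processor({TransitionKey.OBSERVATION: {"state": 1.0}})
print(seen)  # [0]
```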
RobotProcessor provides a powerful, modular approach to data preprocessing in robotics:

- **Dual format support**: Works seamlessly with both EnvTransition dictionaries and batch dictionaries
- **Automatic format conversion**: Converts between dictionary and batch formats as needed
- **LeRobot integration**: Native support for LeRobotDataset and ReplayBuffer formats
- **Clear separation of concerns**: Each transformation is a separate, testable unit
- **Proper state management**: Clear distinction between config (JSON) and state (tensors)
- **Device-aware**: Seamless GPU/CPU transfers with `.to(device)`
- **Inter-step communication**: Use `complementary_data` for passing information between steps
- **Easy sharing**: Push to the Hugging Face Hub for reproducibility
- **Type safety**: Use `TransitionKey` instead of magic numbers
- **Debugging tools**: Step through transformations and add monitoring hooks
- **Flexible conversion**: Customize `to_transition` and `to_output` functions for specific needs

Key advantages of the dual format approach:

- **Environment interaction**: Use the dictionary format for real-time robot control
- **Training/evaluation**: Use the batch format for dataset processing and model training
- **Seamless integration**: The same processor works with both formats automatically
- **Backward compatibility**: Existing code using either format continues to work
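The type-safety point is easy to demonstrate: enum keys make missing fields explicit, while positional indexing into the old tuple format could silently read the wrong field. A self-contained illustration (stand-in enum and hypothetical data, not lerobot's actual types):

```python
from enum import Enum


class TransitionKey(Enum):
    OBSERVATION = "observation"
    ACTION = "action"
    REWARD = "reward"


transition = {TransitionKey.OBSERVATION: {"state": [0.1]}, TransitionKey.REWARD: 1.0}

# Missing keys are explicit: .get() returns None instead of raising
assert transition.get(TransitionKey.ACTION) is None

# With the old tuple format, a wrong index silently returned the wrong field
old_style = ({"state": [0.1]}, None, 1.0)
reward_by_index = old_style[1]   # oops: index 1 is the action, not the reward
reward_by_key = transition[TransitionKey.REWARD]
print(reward_by_index, reward_by_key)  # None 1.0
```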