mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
refactor(pipeline): Transition from tuple to dictionary format for EnvTransition
- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
This commit is contained in:
+257
-175
@@ -112,24 +112,23 @@ RobotProcessor solves these issues by providing a declarative pipeline approach
|
||||
|
||||
RobotProcessor works with two data formats:
|
||||
|
||||
### 1. EnvTransition Tuple Format
|
||||
### 1. EnvTransition Dictionary Format
|
||||
|
||||
An `EnvTransition` is a 7-tuple that represents a complete transition in the environment:
|
||||
An `EnvTransition` is a dictionary that represents a complete transition in the environment:
|
||||
|
||||
```python
|
||||
from lerobot.processor.pipeline import TransitionIndex
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
# EnvTransition structure:
|
||||
# (observation, action, reward, done, truncated, info, complementary_data)
|
||||
transition = (
|
||||
{"observation.image": ..., "observation.state": ...}, # observation at time t
|
||||
[0.1, -0.2, 0.3], # action taken at time t
|
||||
1.0, # reward received
|
||||
False, # episode done flag
|
||||
False, # episode truncated flag
|
||||
{"success": True}, # additional info from environment
|
||||
{"step_idx": 42} # complementary_data for inter-step communication
|
||||
)
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {"observation.image": ..., "observation.state": ...}, # observation at time t
|
||||
TransitionKey.ACTION: [0.1, -0.2, 0.3], # action taken at time t
|
||||
TransitionKey.REWARD: 1.0, # reward received
|
||||
TransitionKey.DONE: False, # episode done flag
|
||||
TransitionKey.TRUNCATED: False, # episode truncated flag
|
||||
TransitionKey.INFO: {"success": True}, # additional info from environment
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"step_idx": 42} # complementary_data for inter-step communication
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Batch Dictionary Format
|
||||
@@ -160,9 +159,17 @@ from lerobot.processor.observation_processor import ImageProcessor
|
||||
|
||||
processor = RobotProcessor([ImageProcessor()])
|
||||
|
||||
# Works with EnvTransition tuples
|
||||
transition = ({"pixels": image_array}, None, 0.0, False, False, {}, {})
|
||||
processed_transition = processor(transition) # Returns EnvTransition tuple
|
||||
# Works with EnvTransition dictionaries
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {"pixels": image_array},
|
||||
TransitionKey.ACTION: None,
|
||||
TransitionKey.REWARD: 0.0,
|
||||
TransitionKey.DONE: False,
|
||||
TransitionKey.TRUNCATED: False,
|
||||
TransitionKey.INFO: {},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {}
|
||||
}
|
||||
processed_transition = processor(transition) # Returns EnvTransition dictionary
|
||||
|
||||
# Also works with batch dictionaries
|
||||
batch = {
|
||||
@@ -176,25 +183,25 @@ batch = {
|
||||
processed_batch = processor(batch) # Returns batch dictionary
|
||||
```
|
||||
|
||||
### Using TransitionIndex
|
||||
### Using TransitionKey
|
||||
|
||||
Instead of using magic numbers to access tuple elements, use the `TransitionIndex` enum:
|
||||
Use the `TransitionKey` enum to access dictionary elements:
|
||||
|
||||
```python
|
||||
from lerobot.processor.pipeline import TransitionIndex
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
# Bad - using magic numbers
|
||||
obs = transition[0]
|
||||
action = transition[1]
|
||||
# Good - using TransitionKey
|
||||
obs = transition[TransitionKey.OBSERVATION]
|
||||
action = transition[TransitionKey.ACTION]
|
||||
reward = transition[TransitionKey.REWARD]
|
||||
done = transition[TransitionKey.DONE]
|
||||
truncated = transition[TransitionKey.TRUNCATED]
|
||||
info = transition[TransitionKey.INFO]
|
||||
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
# Good - using TransitionIndex
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
action = transition[TransitionIndex.ACTION]
|
||||
reward = transition[TransitionIndex.REWARD]
|
||||
done = transition[TransitionIndex.DONE]
|
||||
truncated = transition[TransitionIndex.TRUNCATED]
|
||||
info = transition[TransitionIndex.INFO]
|
||||
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
|
||||
# Alternative - using .get() for optional access
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
```
|
||||
|
||||
### Default Conversion Functions
|
||||
@@ -203,43 +210,49 @@ RobotProcessor uses these default conversion functions:
|
||||
|
||||
```python
|
||||
def _default_batch_to_transition(batch):
|
||||
"""Default conversion from batch dict to EnvTransition tuple."""
|
||||
"""Default conversion from batch dict to EnvTransition dictionary."""
|
||||
# Extract observation keys (anything starting with "observation.")
|
||||
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
observation = observation_keys if observation_keys else None
|
||||
|
||||
observation = None
|
||||
if observation_keys:
|
||||
observation = {}
|
||||
# Keep observation.* keys as-is (don't remove "observation." prefix)
|
||||
for key, value in observation_keys.items():
|
||||
observation[key] = value
|
||||
# Extract padding and task keys for complementary data
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}
|
||||
|
||||
return (
|
||||
observation,
|
||||
batch.get("action"),
|
||||
batch.get("next.reward", 0.0), # Note: "next.reward" not "reward"
|
||||
batch.get("next.done", False), # Note: "next.done" not "done"
|
||||
batch.get("next.truncated", False), # Note: "next.truncated" not "truncated"
|
||||
batch.get("info", {}),
|
||||
{}, # Empty complementary_data
|
||||
)
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: batch.get("action"),
|
||||
TransitionKey.REWARD: batch.get("next.reward", 0.0),
|
||||
TransitionKey.DONE: batch.get("next.done", False),
|
||||
TransitionKey.TRUNCATED: batch.get("next.truncated", False),
|
||||
TransitionKey.INFO: batch.get("info", {}),
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||||
}
|
||||
return transition
|
||||
|
||||
def _default_transition_to_batch(transition):
|
||||
"""Default conversion from EnvTransition tuple to batch dict."""
|
||||
obs, action, reward, done, truncated, info, _ = transition
|
||||
|
||||
"""Default conversion from EnvTransition dictionary to batch dict."""
|
||||
batch = {
|
||||
"action": action,
|
||||
"next.reward": reward, # Note: "next.reward" not "reward"
|
||||
"next.done": done, # Note: "next.done" not "done"
|
||||
"next.truncated": truncated, # Note: "next.truncated" not "truncated"
|
||||
"info": info,
|
||||
"action": transition.get(TransitionKey.ACTION),
|
||||
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
|
||||
"next.done": transition.get(TransitionKey.DONE, False),
|
||||
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
|
||||
"info": transition.get(TransitionKey.INFO, {}),
|
||||
}
|
||||
|
||||
# Flatten observation dict (keep observation.* keys as-is)
|
||||
if isinstance(obs, dict):
|
||||
for key, value in obs.items():
|
||||
batch[key] = value
|
||||
# Add padding and task data from complementary_data
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data:
|
||||
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
|
||||
batch.update(pad_data)
|
||||
if "task" in complementary_data:
|
||||
batch["task"] = complementary_data["task"]
|
||||
|
||||
# Handle observation - flatten dict to observation.* keys if it's a dict
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if isinstance(observation, dict):
|
||||
batch.update(observation)
|
||||
|
||||
return batch
|
||||
```
|
||||
@@ -250,33 +263,32 @@ You can customize how RobotProcessor converts between formats:
|
||||
|
||||
```python
|
||||
def custom_batch_to_transition(batch):
|
||||
"""Custom conversion from batch dict to EnvTransition tuple."""
|
||||
"""Custom conversion from batch dict to EnvTransition dictionary."""
|
||||
# Extract observation keys (anything starting with "observation.")
|
||||
observation = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
|
||||
return (
|
||||
observation,
|
||||
batch.get("action"),
|
||||
batch.get("reward", 0.0), # Use "reward" instead of "next.reward"
|
||||
batch.get("done", False), # Use "done" instead of "next.done"
|
||||
batch.get("truncated", False),
|
||||
batch.get("info", {}),
|
||||
batch.get("complementary_data", {})
|
||||
)
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: batch.get("action"),
|
||||
TransitionKey.REWARD: batch.get("reward", 0.0), # Use "reward" instead of "next.reward"
|
||||
TransitionKey.DONE: batch.get("done", False), # Use "done" instead of "next.done"
|
||||
TransitionKey.TRUNCATED: batch.get("truncated", False),
|
||||
TransitionKey.INFO: batch.get("info", {}),
|
||||
TransitionKey.COMPLEMENTARY_DATA: batch.get("complementary_data", {})
|
||||
}
|
||||
|
||||
def custom_transition_to_batch(transition):
|
||||
"""Custom conversion from EnvTransition tuple to batch dict."""
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
|
||||
"""Custom conversion from EnvTransition dictionary to batch dict."""
|
||||
batch = {
|
||||
"action": action,
|
||||
"reward": reward, # Use "reward" instead of "next.reward"
|
||||
"done": done, # Use "done" instead of "next.done"
|
||||
"truncated": truncated,
|
||||
"info": info,
|
||||
"action": transition.get(TransitionKey.ACTION),
|
||||
"reward": transition.get(TransitionKey.REWARD), # Use "reward" instead of "next.reward"
|
||||
"done": transition.get(TransitionKey.DONE), # Use "done" instead of "next.done"
|
||||
"truncated": transition.get(TransitionKey.TRUNCATED),
|
||||
"info": transition.get(TransitionKey.INFO),
|
||||
}
|
||||
|
||||
# Flatten observation dict
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
if obs:
|
||||
batch.update(obs)
|
||||
|
||||
@@ -292,21 +304,21 @@ processor = RobotProcessor(
|
||||
|
||||
### Advanced: Controlling Output Format with `to_output`
|
||||
|
||||
The `to_output` function determines what format is returned when you call the processor with a batch dictionary. Sometimes you want to output `EnvTransition` tuples even when you input batch dictionaries:
|
||||
The `to_output` function determines what format is returned when you call the processor with a batch dictionary. Sometimes you want to output `EnvTransition` dictionaries even when you input batch dictionaries:
|
||||
|
||||
```python
|
||||
# Identity function to always return EnvTransition tuples
|
||||
# Identity function to always return EnvTransition dictionaries
|
||||
def keep_as_transition(transition):
|
||||
"""Always return EnvTransition tuple regardless of input format."""
|
||||
"""Always return EnvTransition dictionary regardless of input format."""
|
||||
return transition
|
||||
|
||||
# Processor that always outputs EnvTransition tuples
|
||||
# Processor that always outputs EnvTransition dictionaries
|
||||
processor = RobotProcessor(
|
||||
steps=[ImageProcessor(), StateProcessor()],
|
||||
to_output=keep_as_transition # Always return tuple format
|
||||
to_output=keep_as_transition # Always return dictionary format
|
||||
)
|
||||
|
||||
# Even when called with batch dict, returns EnvTransition tuple
|
||||
# Even when called with batch dict, returns EnvTransition dictionary
|
||||
batch = {
|
||||
"observation.image": image_tensor,
|
||||
"action": action_tensor,
|
||||
@@ -316,13 +328,13 @@ batch = {
|
||||
"info": info_dict
|
||||
}
|
||||
|
||||
result = processor(batch) # Returns EnvTransition tuple, not batch dict!
|
||||
print(type(result)) # <class 'tuple'>
|
||||
result = processor(batch) # Returns EnvTransition dictionary, not batch dict!
|
||||
print(type(result)) # <class 'dict'>
|
||||
```
|
||||
|
||||
### Real-World Example: Environment Interaction
|
||||
|
||||
This is particularly useful for environment interaction where you want consistent tuple output:
|
||||
This is particularly useful for environment interaction where you want consistent dictionary output:
|
||||
|
||||
```python
|
||||
from lerobot.processor.observation_processor import VanillaObservationProcessor
|
||||
@@ -332,7 +344,7 @@ from lerobot.processor.observation_processor import VanillaObservationProcessor
|
||||
env_processor = RobotProcessor(
|
||||
[VanillaObservationProcessor()],
|
||||
to_transition=lambda x: x, # Pass through - no conversion needed
|
||||
to_output=lambda x: x, # Always return EnvTransition tuple
|
||||
to_output=lambda x: x, # Always return EnvTransition dictionary
|
||||
)
|
||||
|
||||
# Environment interaction loop
|
||||
@@ -340,12 +352,20 @@ env = make_env()
|
||||
obs, info = env.reset()
|
||||
|
||||
for step in range(1000):
|
||||
# Create transition - input is already in tuple format
|
||||
transition = (obs, None, 0.0, False, False, info, {"step": step})
|
||||
# Create transition - input is already in dictionary format
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: obs,
|
||||
TransitionKey.ACTION: None,
|
||||
TransitionKey.REWARD: 0.0,
|
||||
TransitionKey.DONE: False,
|
||||
TransitionKey.TRUNCATED: False,
|
||||
TransitionKey.INFO: info,
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"step": step}
|
||||
}
|
||||
|
||||
# Process - output is guaranteed to be EnvTransition tuple
|
||||
# Process - output is guaranteed to be EnvTransition dictionary
|
||||
processed_transition = env_processor(transition)
|
||||
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
|
||||
processed_obs = processed_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Use with policy
|
||||
action = policy.select_action(processed_obs)
|
||||
@@ -357,12 +377,12 @@ for step in range(1000):
|
||||
|
||||
### When to Use Different Output Formats
|
||||
|
||||
**Use EnvTransition tuple output when:**
|
||||
**Use EnvTransition dictionary output when:**
|
||||
|
||||
- Environment interaction and real-time control
|
||||
- You need to access individual transition components frequently
|
||||
- Performance is critical (avoids the overhead of converting the transition to a batch dictionary)
|
||||
- Working with gym environments that expect tuple format
|
||||
- Working with gym environments that expect structured data
|
||||
- You need the flexibility of dictionary operations
|
||||
|
||||
**Use batch dictionary output when:**
|
||||
|
||||
@@ -372,10 +392,10 @@ for step in range(1000):
|
||||
- You need the standardized "next.\*" key format
|
||||
|
||||
```python
|
||||
# For environment interaction - use tuple output
|
||||
# For environment interaction - use dictionary output
|
||||
env_processor = RobotProcessor(
|
||||
steps=[ImageProcessor(), StateProcessor()],
|
||||
to_output=lambda x: x # Return EnvTransition tuple
|
||||
to_output=lambda x: x # Return EnvTransition dictionary
|
||||
)
|
||||
|
||||
# For training - use batch output (default)
|
||||
@@ -391,9 +411,17 @@ for batch in dataloader:
|
||||
|
||||
# Environment loop
|
||||
for step in range(1000):
|
||||
transition = (obs, None, 0.0, False, False, info, {})
|
||||
processed_transition = env_processor(transition) # Returns EnvTransition tuple
|
||||
obs = processed_transition[TransitionIndex.OBSERVATION]
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: obs,
|
||||
TransitionKey.ACTION: None,
|
||||
TransitionKey.REWARD: 0.0,
|
||||
TransitionKey.DONE: False,
|
||||
TransitionKey.TRUNCATED: False,
|
||||
TransitionKey.INFO: info,
|
||||
TransitionKey.COMPLEMENTARY_DATA: {}
|
||||
}
|
||||
processed_transition = env_processor(transition) # Returns EnvTransition dictionary
|
||||
obs = processed_transition[TransitionKey.OBSERVATION]
|
||||
action = policy.select_action(obs)
|
||||
```
|
||||
|
||||
@@ -426,7 +454,7 @@ batch = {
|
||||
Let's create a processor that properly handles image and state preprocessing:
|
||||
|
||||
```python
|
||||
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
||||
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
|
||||
from lerobot.processor.observation_processor import ImageProcessor, StateProcessor
|
||||
import numpy as np
|
||||
|
||||
@@ -440,7 +468,15 @@ observation = {
|
||||
}
|
||||
|
||||
# Create a full transition
|
||||
transition = (observation, None, 0.0, False, False, {}, {})
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: None,
|
||||
TransitionKey.REWARD: 0.0,
|
||||
TransitionKey.DONE: False,
|
||||
TransitionKey.TRUNCATED: False,
|
||||
TransitionKey.INFO: {},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {}
|
||||
}
|
||||
|
||||
# Create and use the processor
|
||||
processor = RobotProcessor([
|
||||
@@ -449,7 +485,7 @@ processor = RobotProcessor([
|
||||
])
|
||||
|
||||
processed_transition = processor(transition)
|
||||
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
|
||||
processed_obs = processed_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check the results
|
||||
print("Original keys:", observation.keys())
|
||||
@@ -541,11 +577,19 @@ obs, info = env.reset()
|
||||
|
||||
for step in range(1000):
|
||||
# Raw environment observation
|
||||
transition = (obs, None, 0.0, False, False, info, {})
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: obs,
|
||||
TransitionKey.ACTION: None,
|
||||
TransitionKey.REWARD: 0.0,
|
||||
TransitionKey.DONE: False,
|
||||
TransitionKey.TRUNCATED: False,
|
||||
TransitionKey.INFO: info,
|
||||
TransitionKey.COMPLEMENTARY_DATA: {}
|
||||
}
|
||||
|
||||
# Process for policy input
|
||||
processed_transition = online_processor(transition)
|
||||
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
|
||||
processed_obs = processed_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Get action from policy
|
||||
action = policy.select_action(processed_obs)
|
||||
@@ -585,15 +629,16 @@ class ImagePadder:
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Main processing method - required for all steps."""
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
if obs is None:
|
||||
return transition
|
||||
|
||||
# Process all image observations
|
||||
for key in list(obs.keys()):
|
||||
processed_obs = dict(obs) # Create a copy
|
||||
for key in list(processed_obs.keys()):
|
||||
if key.startswith("observation.images."):
|
||||
img = obs[key]
|
||||
img = processed_obs[key]
|
||||
# Calculate padding
|
||||
_, _, h, w = img.shape
|
||||
pad_h = max(0, self.target_height - h)
|
||||
@@ -609,10 +654,12 @@ class ImagePadder:
|
||||
img = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom),
|
||||
mode='constant', value=self.pad_value)
|
||||
|
||||
obs[key] = img
|
||||
processed_obs[key] = img
|
||||
|
||||
# Return modified transition
|
||||
return (obs, *transition[1:])
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
"""Return JSON-serializable configuration - required for save/load."""
|
||||
@@ -694,8 +741,8 @@ class ImageStatisticsCalculator:
|
||||
"""Calculate image statistics and pass to next steps."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {}
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
if obs is None:
|
||||
return transition
|
||||
@@ -714,18 +761,13 @@ class ImageStatisticsCalculator:
|
||||
image_stats[key] = stats
|
||||
|
||||
# Store in complementary_data for next steps
|
||||
comp_data = dict(comp_data) # Make a copy
|
||||
comp_data["image_statistics"] = image_stats
|
||||
|
||||
# Return transition with updated complementary_data
|
||||
return (
|
||||
obs,
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
comp_data # Updated complementary_data
|
||||
)
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
|
||||
return new_transition
|
||||
|
||||
@dataclass
|
||||
class AdaptiveBrightnessAdjuster:
|
||||
@@ -734,8 +776,8 @@ class AdaptiveBrightnessAdjuster:
|
||||
target_brightness: float = 0.5
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {}
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
if obs is None or "image_statistics" not in comp_data:
|
||||
return transition
|
||||
@@ -743,15 +785,18 @@ class AdaptiveBrightnessAdjuster:
|
||||
# Use statistics from previous step
|
||||
image_stats = comp_data["image_statistics"]
|
||||
|
||||
for key in obs:
|
||||
processed_obs = dict(obs) # Create a copy
|
||||
for key in processed_obs:
|
||||
if key.startswith("observation.images.") and key in image_stats:
|
||||
current_mean = image_stats[key]["mean"]
|
||||
brightness_adjust = self.target_brightness - current_mean
|
||||
|
||||
# Adjust brightness
|
||||
obs[key] = torch.clamp(obs[key] + brightness_adjust, 0, 1)
|
||||
processed_obs[key] = torch.clamp(processed_obs[key] + brightness_adjust, 0, 1)
|
||||
|
||||
return (obs, *transition[1:])
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
|
||||
# Use them together
|
||||
processor = RobotProcessor([
|
||||
@@ -782,7 +827,7 @@ class ActionRepeatStep:
|
||||
env: gym.Env = None # This can't be serialized to JSON!
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
|
||||
if self.env is not None and action is not None:
|
||||
# Repeat action multiple times in environment
|
||||
@@ -792,9 +837,13 @@ class ActionRepeatStep:
|
||||
total_reward += r
|
||||
if d or t:
|
||||
break
|
||||
reward = total_reward
|
||||
|
||||
return (obs, action, reward, done, truncated, info, comp_data)
|
||||
# Update reward in transition
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.REWARD] = total_reward
|
||||
return new_transition
|
||||
|
||||
return transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
# Note: env is NOT included because it's not serializable
|
||||
@@ -1211,7 +1260,7 @@ This enables sharing of preprocessing logic while allowing each user to provide
|
||||
Here's a complete example showing proper device management and all features:
|
||||
|
||||
```python
|
||||
from lerobot.processor.pipeline import RobotProcessor, ProcessorStepRegistry, TransitionIndex
|
||||
from lerobot.processor.pipeline import RobotProcessor, ProcessorStepRegistry, TransitionKey
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
@@ -1224,23 +1273,29 @@ class DeviceMover:
|
||||
device: str = "cuda"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
if obs is None:
|
||||
return transition
|
||||
|
||||
# Move all tensor observations to device
|
||||
for key, value in obs.items():
|
||||
processed_obs = dict(obs) # Create a copy
|
||||
for key, value in processed_obs.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
obs[key] = value.to(self.device)
|
||||
processed_obs[key] = value.to(self.device)
|
||||
|
||||
# Also handle action if present
|
||||
action = transition[TransitionIndex.ACTION]
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and isinstance(action, torch.Tensor):
|
||||
action = action.to(self.device)
|
||||
return (obs, action, *transition[2:])
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
return new_transition
|
||||
|
||||
return (obs, *transition[1:])
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
return {"device": str(self.device)}
|
||||
@@ -1260,7 +1315,7 @@ class RunningNormalizer:
|
||||
self.initialized = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
if obs is None or "observation.state" not in obs:
|
||||
return transition
|
||||
@@ -1284,9 +1339,12 @@ class RunningNormalizer:
|
||||
|
||||
# Normalize
|
||||
state_normalized = (state - self.running_mean) / (self.running_var + 1e-8).sqrt()
|
||||
obs["observation.state"] = state_normalized
|
||||
processed_obs = dict(obs) # Create a copy
|
||||
processed_obs["observation.state"] = state_normalized
|
||||
|
||||
return (obs, *transition[1:])
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
return {
|
||||
@@ -1317,7 +1375,7 @@ class RunningNormalizer:
|
||||
processor = RobotProcessor([
|
||||
ImageProcessor(), # Convert images to float32 [0,1]
|
||||
StateProcessor(), # Convert states to torch tensors
|
||||
ImagePadder(224, 224), # Pad images to standard size
|
||||
ImagePadder(target_height=224, target_width=224), # Pad images to standard size
|
||||
DeviceMover("cuda"), # Move everything to GPU
|
||||
RunningNormalizer(7), # Normalize states
|
||||
], name="CompletePreprocessor")
|
||||
@@ -1330,11 +1388,19 @@ obs = {
|
||||
"pixels": {"cam": np.random.randint(0, 255, (200, 300, 3), dtype=np.uint8)},
|
||||
"agent_pos": np.random.randn(7).astype(np.float32)
|
||||
}
|
||||
transition = (obs, None, 0.0, False, False, {}, {})
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: obs,
|
||||
TransitionKey.ACTION: None,
|
||||
TransitionKey.REWARD: 0.0,
|
||||
TransitionKey.DONE: False,
|
||||
TransitionKey.TRUNCATED: False,
|
||||
TransitionKey.INFO: {},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {}
|
||||
}
|
||||
|
||||
# Everything is processed and on GPU
|
||||
processed = processor(transition)
|
||||
print(processed[TransitionIndex.OBSERVATION]["observation.images.cam"].device) # cuda:0
|
||||
print(processed[TransitionKey.OBSERVATION]["observation.images.cam"].device) # cuda:0
|
||||
```
|
||||
|
||||
## Solving Real-World Problems with RobotProcessor
|
||||
@@ -1358,22 +1424,25 @@ class KeyRemapper:
|
||||
})
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
if obs is None:
|
||||
return transition
|
||||
|
||||
# Create new observation with renamed keys
|
||||
processed_obs = dict(obs) # Create a copy
|
||||
renamed_obs = {}
|
||||
for old_key, new_key in self.key_mapping.items():
|
||||
if old_key in obs:
|
||||
renamed_obs[new_key] = obs[old_key]
|
||||
if old_key in processed_obs:
|
||||
renamed_obs[new_key] = processed_obs[old_key]
|
||||
|
||||
# Keep any unmapped keys as-is
|
||||
for key, value in obs.items():
|
||||
for key, value in processed_obs.items():
|
||||
if key not in self.key_mapping:
|
||||
renamed_obs[key] = value
|
||||
|
||||
return (renamed_obs, *transition[1:])
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = renamed_obs
|
||||
return new_transition
|
||||
```
|
||||
|
||||
### Workspace-Focused Image Processing
|
||||
@@ -1390,13 +1459,14 @@ class WorkspaceCropper:
|
||||
output_size: Tuple[int, int] = (224, 224)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
if obs is None:
|
||||
return transition
|
||||
|
||||
for key in list(obs.keys()):
|
||||
processed_obs = dict(obs) # Create a copy
|
||||
for key in list(processed_obs.keys()):
|
||||
if key.startswith("observation.images."):
|
||||
img = obs[key]
|
||||
img = processed_obs[key]
|
||||
# Crop to workspace
|
||||
x1, y1, x2, y2 = self.crop_bbox
|
||||
img_cropped = img[:, :, y1:y2, x1:x2]
|
||||
@@ -1407,9 +1477,11 @@ class WorkspaceCropper:
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
obs[key] = img_resized
|
||||
processed_obs[key] = img_resized
|
||||
|
||||
return (obs, *transition[1:])
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
```
|
||||
|
||||
### Building Complete Pipelines for Different Robots
|
||||
@@ -1471,7 +1543,7 @@ The beauty of this approach is that:
|
||||
|
||||
```python
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
# Always check if observation exists
|
||||
if obs is None:
|
||||
@@ -1496,7 +1568,7 @@ return (modified_obs, None, 0.0, False, False, {}, {})
|
||||
|
||||
```python
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
if self.store_previous:
|
||||
# Good - clone to avoid reference issues
|
||||
@@ -1522,7 +1594,7 @@ def state_dict(self) -> Dict[str, torch.Tensor]:
|
||||
Here's how to use RobotProcessor in a real robot control loop, showing both the EnvTransition dictionary and batch formats:
|
||||
|
||||
```python
|
||||
from lerobot.processor.pipeline import RobotProcessor, ProcessorStepRegistry, TransitionIndex
|
||||
from lerobot.processor.pipeline import RobotProcessor, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from pathlib import Path
|
||||
import time
|
||||
@@ -1545,11 +1617,13 @@ class ActionClipper:
|
||||
max_value: float = 1.0
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition[TransitionIndex.ACTION]
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
|
||||
if action is not None:
|
||||
action = torch.clamp(action, self.min_value, self.max_value)
|
||||
return (transition[TransitionIndex.OBSERVATION], action, *transition[2:])
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
return new_transition
|
||||
|
||||
return transition
|
||||
|
||||
@@ -1578,28 +1652,36 @@ for episode in range(10):
|
||||
|
||||
for step in range(1000):
|
||||
# Create transition with raw observation
|
||||
transition = (obs, None, 0.0, False, False, info, {"step": step})
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: obs,
|
||||
TransitionKey.ACTION: None,
|
||||
TransitionKey.REWARD: 0.0,
|
||||
TransitionKey.DONE: False,
|
||||
TransitionKey.TRUNCATED: False,
|
||||
TransitionKey.INFO: info,
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"step": step}
|
||||
}
|
||||
|
||||
# Preprocess - works with tuple format
|
||||
# Preprocess - works with dictionary format
|
||||
processed_transition = preprocessor(transition)
|
||||
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
|
||||
processed_obs = processed_transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
# Get action from policy
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(processed_obs)
|
||||
|
||||
# Postprocess action
|
||||
action_transition = (
|
||||
processed_obs,
|
||||
action,
|
||||
0.0,
|
||||
False,
|
||||
False,
|
||||
info,
|
||||
{"raw_action": action.clone()} # Store raw action in complementary_data
|
||||
)
|
||||
action_transition = {
|
||||
TransitionKey.OBSERVATION: processed_obs,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: 0.0,
|
||||
TransitionKey.DONE: False,
|
||||
TransitionKey.TRUNCATED: False,
|
||||
TransitionKey.INFO: info,
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"raw_action": action.clone()} # Store raw action in complementary_data
|
||||
}
|
||||
processed_action_transition = postprocessor(action_transition)
|
||||
final_action = processed_action_transition[TransitionIndex.ACTION]
|
||||
final_action = processed_action_transition.get(TransitionKey.ACTION)
|
||||
|
||||
# Execute action
|
||||
obs, reward, terminated, truncated, info = env.step(final_action.cpu().numpy())
|
||||
@@ -1667,7 +1749,7 @@ Use the full power of `RobotProcessor` for debugging:
|
||||
```python
|
||||
# Enable detailed logging
|
||||
def log_observation_shapes(step_idx: int, transition: EnvTransition):
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
if obs:
|
||||
print(f"Step {step_idx} observations:")
|
||||
for key, value in obs.items():
|
||||
@@ -1679,7 +1761,7 @@ processor.register_after_step_hook(log_observation_shapes)
|
||||
|
||||
# Monitor complementary data flow
|
||||
def monitor_complementary_data(step_idx: int, transition: EnvTransition):
|
||||
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
|
||||
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if comp_data:
|
||||
print(f"Step {step_idx} complementary_data: {list(comp_data.keys())}")
|
||||
return None
|
||||
@@ -1688,7 +1770,7 @@ processor.register_before_step_hook(monitor_complementary_data)
|
||||
|
||||
# Validate data integrity
|
||||
def validate_tensors(step_idx: int, transition: EnvTransition):
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
if obs:
|
||||
for key, value in obs.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
@@ -1705,21 +1787,21 @@ processor.register_after_step_hook(validate_tensors)
|
||||
|
||||
RobotProcessor provides a powerful, modular approach to data preprocessing in robotics:
|
||||
|
||||
- **Dual format support**: Works seamlessly with both EnvTransition tuples and batch dictionaries
|
||||
- **Automatic format conversion**: Converts between tuple and batch formats as needed
|
||||
- **Dual format support**: Works seamlessly with both EnvTransition dictionaries and batch dictionaries
|
||||
- **Automatic format conversion**: Converts between dictionary and batch formats as needed
|
||||
- **LeRobot integration**: Native support for LeRobotDataset and ReplayBuffer formats
|
||||
- **Clear separation of concerns**: Each transformation is a separate, testable unit
|
||||
- **Proper state management**: Clear distinction between config (JSON) and state (tensors)
|
||||
- **Device-aware**: Seamless GPU/CPU transfers with `.to(device)`
|
||||
- **Inter-step communication**: Use `complementary_data` for passing information
|
||||
- **Easy sharing**: Push to Hugging Face Hub for reproducibility
|
||||
- **Type safety**: Use `TransitionIndex` instead of magic numbers
|
||||
- **Type safety**: Use `TransitionKey` instead of raw string keys
|
||||
- **Debugging tools**: Step through transformations and add monitoring hooks
|
||||
- **Flexible conversion**: Customize `to_transition` and `to_output` functions for specific needs
|
||||
|
||||
Key advantages of the dual format approach:
|
||||
|
||||
- **Environment interaction**: Use tuple format for real-time robot control
|
||||
- **Environment interaction**: Use dictionary format for real-time robot control
|
||||
- **Training/evaluation**: Use batch format for dataset processing and model training
|
||||
- **Seamless integration**: Same processor works with both formats automatically
|
||||
- **Backward compatibility**: Existing code using either format continues to work
|
||||
|
||||
@@ -36,17 +36,25 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
Returns:
|
||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||
"""
|
||||
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
|
||||
from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor
|
||||
|
||||
# Create processor with observation processor
|
||||
processor = RobotProcessor([VanillaObservationProcessor()])
|
||||
|
||||
# Create transition tuple and process
|
||||
transition = (observations, None, None, None, None, None, None)
|
||||
processed_transition = processor(transition)
|
||||
# Create transition dictionary and process
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: observations,
|
||||
TransitionKey.ACTION: None,
|
||||
TransitionKey.REWARD: None,
|
||||
TransitionKey.DONE: None,
|
||||
TransitionKey.TRUNCATED: None,
|
||||
TransitionKey.INFO: None,
|
||||
TransitionKey.COMPLEMENTARY_DATA: None,
|
||||
}
|
||||
result = processor(transition)
|
||||
|
||||
# Return processed observations
|
||||
return processed_transition[TransitionIndex.OBSERVATION]
|
||||
# Extract and return the processed observation
|
||||
return result[TransitionKey.OBSERVATION]
|
||||
|
||||
|
||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
|
||||
@@ -32,7 +32,7 @@ from .pipeline import (
|
||||
ProcessorStepRegistry,
|
||||
RewardProcessor,
|
||||
RobotProcessor,
|
||||
TransitionIndex,
|
||||
TransitionKey,
|
||||
TruncatedProcessor,
|
||||
)
|
||||
from .rename_processor import RenameProcessor
|
||||
@@ -54,7 +54,7 @@ __all__ = [
|
||||
"RewardProcessor",
|
||||
"RobotProcessor",
|
||||
"StateProcessor",
|
||||
"TransitionIndex",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessor",
|
||||
"VanillaObservationProcessor",
|
||||
]
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionIndex
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -35,30 +35,41 @@ class DeviceProcessor:
|
||||
self.non_blocking = "cuda" in self.device
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation: dict[str, torch.Tensor] = transition[TransitionIndex.OBSERVATION]
|
||||
action = transition[TransitionIndex.ACTION]
|
||||
reward = transition[TransitionIndex.REWARD]
|
||||
done = transition[TransitionIndex.DONE]
|
||||
truncated = transition[TransitionIndex.TRUNCATED]
|
||||
info = transition[TransitionIndex.INFO]
|
||||
complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
|
||||
# Create a copy of the transition
|
||||
new_transition = transition.copy()
|
||||
|
||||
# Process observation tensors
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is not None:
|
||||
observation = {
|
||||
k: v.to(self.device, non_blocking=self.non_blocking) for k, v in observation.items()
|
||||
new_observation = {
|
||||
k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in observation.items()
|
||||
}
|
||||
if action is not None:
|
||||
action = action.to(self.device)
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
|
||||
return (
|
||||
observation,
|
||||
action,
|
||||
reward,
|
||||
done,
|
||||
truncated,
|
||||
info,
|
||||
complementary_data,
|
||||
)
|
||||
# Process action tensor
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and isinstance(action, torch.Tensor):
|
||||
new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)
|
||||
|
||||
# Process reward tensor
|
||||
reward = transition.get(TransitionKey.REWARD)
|
||||
if reward is not None and isinstance(reward, torch.Tensor):
|
||||
new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)
|
||||
|
||||
# Process done tensor
|
||||
done = transition.get(TransitionKey.DONE)
|
||||
if done is not None and isinstance(done, torch.Tensor):
|
||||
new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)
|
||||
|
||||
# Process truncated tensor
|
||||
truncated = transition.get(TransitionKey.TRUNCATED)
|
||||
if truncated is not None and isinstance(truncated, torch.Tensor):
|
||||
new_transition[TransitionKey.TRUNCATED] = truncated.to(
|
||||
self.device, non_blocking=self.non_blocking
|
||||
)
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from collections.abc import Mapping
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -10,7 +10,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
|
||||
@@ -166,17 +166,14 @@ class NormalizerProcessor:
|
||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = self._normalize_obs(transition[TransitionIndex.OBSERVATION])
|
||||
action = self._normalize_action(transition[TransitionIndex.ACTION])
|
||||
return (
|
||||
observation,
|
||||
action,
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||
action = self._normalize_action(transition.get(TransitionKey.ACTION))
|
||||
|
||||
# Create a new transition with normalized values
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
config = {
|
||||
@@ -297,17 +294,14 @@ class UnnormalizerProcessor:
|
||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = self._unnormalize_obs(transition[TransitionIndex.OBSERVATION])
|
||||
action = self._unnormalize_action(transition[TransitionIndex.ACTION])
|
||||
return (
|
||||
observation,
|
||||
action,
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||
action = self._unnormalize_action(transition.get(TransitionKey.ACTION))
|
||||
|
||||
# Create a new transition with unnormalized values
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
||||
@@ -21,7 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -36,7 +36,7 @@ class ImageProcessor:
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition[TransitionIndex.OBSERVATION]
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
if observation is None:
|
||||
return transition
|
||||
@@ -60,15 +60,9 @@ class ImageProcessor:
|
||||
processed_obs[key] = value
|
||||
|
||||
# Return new transition with processed observation
|
||||
return (
|
||||
processed_obs,
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
|
||||
def _process_single_image(self, img: np.ndarray) -> Tensor:
|
||||
"""Process a single image array."""
|
||||
@@ -124,7 +118,7 @@ class StateProcessor:
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition[TransitionIndex.OBSERVATION]
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
if observation is None:
|
||||
return transition
|
||||
@@ -150,15 +144,9 @@ class StateProcessor:
|
||||
del processed_obs["agent_pos"]
|
||||
|
||||
# Return new transition with processed observation
|
||||
return (
|
||||
processed_obs,
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
|
||||
+132
-152
@@ -18,39 +18,42 @@ from __future__ import annotations
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol, TypedDict
|
||||
|
||||
import torch
|
||||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
class TransitionIndex(IntEnum):
|
||||
"""Explicit indices for EnvTransition tuple components."""
|
||||
class TransitionKey(str, Enum):
|
||||
"""Keys for accessing EnvTransition dictionary components."""
|
||||
|
||||
OBSERVATION = 0
|
||||
ACTION = 1
|
||||
REWARD = 2
|
||||
DONE = 3
|
||||
TRUNCATED = 4
|
||||
INFO = 5
|
||||
COMPLEMENTARY_DATA = 6
|
||||
OBSERVATION = "observation"
|
||||
ACTION = "action"
|
||||
REWARD = "reward"
|
||||
DONE = "done"
|
||||
TRUNCATED = "truncated"
|
||||
INFO = "info"
|
||||
COMPLEMENTARY_DATA = "complementary_data"
|
||||
|
||||
|
||||
# (observation, action, reward, done, truncated, info, complementary_data)
|
||||
EnvTransition = tuple[
|
||||
dict[str, Any] | None, # observation
|
||||
Any | torch.Tensor | None, # action
|
||||
float | torch.Tensor | None, # reward
|
||||
bool | torch.Tensor | None, # done
|
||||
bool | torch.Tensor | None, # truncated
|
||||
dict[str, Any] | None, # info
|
||||
dict[str, Any] | None, # complementary_data
|
||||
]
|
||||
class EnvTransition(TypedDict, total=False):
|
||||
"""Environment transition data structure.
|
||||
|
||||
All fields are optional (total=False) to allow flexible usage.
|
||||
"""
|
||||
|
||||
observation: dict[str, Any] | None
|
||||
action: Any | torch.Tensor | None
|
||||
reward: float | torch.Tensor | None
|
||||
done: bool | torch.Tensor | None
|
||||
truncated: bool | torch.Tensor | None
|
||||
info: dict[str, Any] | None
|
||||
complementary_data: dict[str, Any] | None
|
||||
|
||||
|
||||
class ProcessorStepRegistry:
|
||||
@@ -165,10 +168,9 @@ class ProcessorStep(Protocol):
|
||||
|
||||
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
|
||||
"""Convert a *batch* dict coming from LeRobot replay/dataset code into an
|
||||
``EnvTransition`` tuple.
|
||||
``EnvTransition`` dictionary.
|
||||
|
||||
The function is intentionally **strictly positional** – it maps well known
|
||||
keys to the fixed slot order used inside the pipeline. Missing keys are
|
||||
The function maps well-known keys to the EnvTransition structure. Missing keys are
|
||||
filled with sane defaults (``None`` or ``0.0``/``False``).
|
||||
|
||||
Keys recognised (case-sensitive):
|
||||
@@ -193,15 +195,16 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}
|
||||
|
||||
return (
|
||||
observation,
|
||||
batch.get("action"),
|
||||
batch.get("next.reward", 0.0),
|
||||
batch.get("next.done", False),
|
||||
batch.get("next.truncated", False),
|
||||
batch.get("info", {}),
|
||||
complementary_data,
|
||||
)
|
||||
transition: EnvTransition = {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: batch.get("action"),
|
||||
TransitionKey.REWARD: batch.get("next.reward", 0.0),
|
||||
TransitionKey.DONE: batch.get("next.done", False),
|
||||
TransitionKey.TRUNCATED: batch.get("next.truncated", False),
|
||||
TransitionKey.INFO: batch.get("info", {}),
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||||
}
|
||||
return transition
|
||||
|
||||
|
||||
def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: # noqa: D401
|
||||
@@ -209,25 +212,16 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
the canonical field names used throughout *LeRobot*.
|
||||
"""
|
||||
|
||||
(
|
||||
observation,
|
||||
action,
|
||||
reward,
|
||||
done,
|
||||
truncated,
|
||||
info,
|
||||
complementary_data,
|
||||
) = transition
|
||||
|
||||
batch = {
|
||||
"action": action,
|
||||
"next.reward": reward,
|
||||
"next.done": done,
|
||||
"next.truncated": truncated,
|
||||
"info": info,
|
||||
"action": transition.get(TransitionKey.ACTION),
|
||||
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
|
||||
"next.done": transition.get(TransitionKey.DONE, False),
|
||||
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
|
||||
"info": transition.get(TransitionKey.INFO, {}),
|
||||
}
|
||||
|
||||
# Add padding and task data from complementary_data
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data:
|
||||
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
|
||||
batch.update(pad_data)
|
||||
@@ -236,6 +230,7 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
batch["task"] = complementary_data["task"]
|
||||
|
||||
# Handle observation - flatten dict to observation.* keys if it's a dict
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if isinstance(observation, dict):
|
||||
batch.update(observation)
|
||||
|
||||
@@ -293,33 +288,35 @@ class RobotProcessor(ModelHubMixin):
|
||||
def __call__(self, data: EnvTransition | dict[str, Any]):
|
||||
"""Process data through all steps.
|
||||
|
||||
The method accepts either the classic EnvTransition tuple or a batch dictionary
|
||||
The method accepts either the classic EnvTransition dict or a batch dictionary
|
||||
(like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied
|
||||
it is first converted to the internal tuple format using to_transition; after all
|
||||
steps are executed the tuple is transformed back into a dict with to_batch and the
|
||||
it is first converted to the internal dict format using to_transition; after all
|
||||
steps are executed the dict is transformed back into a batch dict with to_batch and the
|
||||
result is returned – thereby preserving the caller's original data type.
|
||||
|
||||
Args:
|
||||
data: Either an EnvTransition tuple or a batch dictionary to process.
|
||||
data: Either an EnvTransition dict or a batch dictionary to process.
|
||||
|
||||
Returns:
|
||||
The processed data in the same format as the input (tuple or dict).
|
||||
The processed data in the same format as the input (EnvTransition or batch dict).
|
||||
|
||||
Raises:
|
||||
ValueError: If the transition is not a valid 7-tuple format.
|
||||
ValueError: If the transition is not a valid EnvTransition format.
|
||||
"""
|
||||
|
||||
called_with_batch = isinstance(data, dict)
|
||||
# Check if data is already an EnvTransition or needs conversion
|
||||
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
|
||||
# It's a batch dict, convert it
|
||||
called_with_batch = True
|
||||
transition = self.to_transition(data)
|
||||
else:
|
||||
# It's already an EnvTransition
|
||||
called_with_batch = False
|
||||
transition = data
|
||||
|
||||
transition = self.to_transition(data) if called_with_batch else data
|
||||
|
||||
# Basic validation with helpful error message for tuple input
|
||||
if not isinstance(transition, tuple) or len(transition) != 7:
|
||||
raise ValueError(
|
||||
"EnvTransition must be a 7-tuple of (observation, action, reward, done, "
|
||||
"truncated, info, complementary_data). "
|
||||
f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
|
||||
)
|
||||
# Basic validation
|
||||
if not isinstance(transition, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
|
||||
|
||||
for idx, processor_step in enumerate(self.steps):
|
||||
for hook in self.before_step_hooks:
|
||||
@@ -339,25 +336,28 @@ class RobotProcessor(ModelHubMixin):
|
||||
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
|
||||
"""Yield the intermediate results after each processor step.
|
||||
|
||||
Like __call__, this method accepts either EnvTransition tuples or batch dictionaries
|
||||
Like __call__, this method accepts either EnvTransition dicts or batch dictionaries
|
||||
and preserves the input format in the yielded results.
|
||||
|
||||
Args:
|
||||
data: Either an EnvTransition tuple or a batch dictionary to process.
|
||||
data: Either an EnvTransition dict or a batch dictionary to process.
|
||||
|
||||
Yields:
|
||||
The intermediate results after each step, in the same format as the input.
|
||||
"""
|
||||
called_with_batch = isinstance(data, dict)
|
||||
transition = self.to_transition(data) if called_with_batch else data
|
||||
# Check if data is already an EnvTransition or needs conversion
|
||||
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
|
||||
# It's a batch dict, convert it
|
||||
called_with_batch = True
|
||||
transition = self.to_transition(data)
|
||||
else:
|
||||
# It's already an EnvTransition
|
||||
called_with_batch = False
|
||||
transition = data
|
||||
|
||||
# Basic validation with helpful error message for tuple input
|
||||
if not isinstance(transition, tuple) or len(transition) != 7:
|
||||
raise ValueError(
|
||||
"EnvTransition must be a 7-tuple of (observation, action, reward, done, "
|
||||
"truncated, info, complementary_data). "
|
||||
f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
|
||||
)
|
||||
# Basic validation
|
||||
if not isinstance(transition, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
|
||||
|
||||
# Yield initial state
|
||||
yield self.to_output(transition) if called_with_batch else transition
|
||||
@@ -684,7 +684,7 @@ class ObservationProcessor:
|
||||
|
||||
Subclasses should override the `observation` method to implement custom observation processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed observation
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -696,7 +696,7 @@ class ObservationProcessor:
|
||||
return observation * self.scale_factor
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||||
manipulation, focusing only on the specific observation processing logic.
|
||||
"""
|
||||
|
||||
@@ -712,10 +712,12 @@ class ObservationProcessor:
|
||||
return observation
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition[TransitionIndex.OBSERVATION]
|
||||
observation = self.observation(observation)
|
||||
transition = (observation, *transition[TransitionIndex.ACTION :])
|
||||
return transition
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
processed_observation = self.observation(observation)
|
||||
# Create a new transition dict with the processed observation
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_observation
|
||||
return new_transition
|
||||
|
||||
|
||||
class ActionProcessor:
|
||||
@@ -723,7 +725,7 @@ class ActionProcessor:
|
||||
|
||||
Subclasses should override the `action` method to implement custom action processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed action
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -736,7 +738,7 @@ class ActionProcessor:
|
||||
return np.clip(action, self.min_val, self.max_val)
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||||
manipulation, focusing only on the specific action processing logic.
|
||||
"""
|
||||
|
||||
@@ -752,10 +754,12 @@ class ActionProcessor:
|
||||
return action
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition[TransitionIndex.ACTION]
|
||||
action = self.action(action)
|
||||
transition = (transition[TransitionIndex.OBSERVATION], action, *transition[TransitionIndex.REWARD :])
|
||||
return transition
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
processed_action = self.action(action)
|
||||
# Create a new transition dict with the processed action
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.ACTION] = processed_action
|
||||
return new_transition
|
||||
|
||||
|
||||
class RewardProcessor:
|
||||
@@ -763,7 +767,7 @@ class RewardProcessor:
|
||||
|
||||
Subclasses should override the `reward` method to implement custom reward processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed reward
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -775,7 +779,7 @@ class RewardProcessor:
|
||||
return reward * self.scale_factor
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||||
manipulation, focusing only on the specific reward processing logic.
|
||||
"""
|
||||
|
||||
@@ -791,15 +795,12 @@ class RewardProcessor:
|
||||
return reward
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
reward = transition[TransitionIndex.REWARD]
|
||||
reward = self.reward(reward)
|
||||
transition = (
|
||||
transition[TransitionIndex.OBSERVATION],
|
||||
transition[TransitionIndex.ACTION],
|
||||
reward,
|
||||
*transition[TransitionIndex.DONE :],
|
||||
)
|
||||
return transition
|
||||
reward = transition.get(TransitionKey.REWARD)
|
||||
processed_reward = self.reward(reward)
|
||||
# Create a new transition dict with the processed reward
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.REWARD] = processed_reward
|
||||
return new_transition
|
||||
|
||||
|
||||
class DoneProcessor:
|
||||
@@ -807,7 +808,7 @@ class DoneProcessor:
|
||||
|
||||
Subclasses should override the `done` method to implement custom done flag processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed done flag
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -824,7 +825,7 @@ class DoneProcessor:
|
||||
self.steps = 0
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||||
manipulation, focusing only on the specific done flag processing logic.
|
||||
"""
|
||||
|
||||
@@ -840,16 +841,12 @@ class DoneProcessor:
|
||||
return done
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
done = transition[TransitionIndex.DONE]
|
||||
done = self.done(done)
|
||||
transition = (
|
||||
transition[TransitionIndex.OBSERVATION],
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
done,
|
||||
*transition[TransitionIndex.TRUNCATED :],
|
||||
)
|
||||
return transition
|
||||
done = transition.get(TransitionKey.DONE)
|
||||
processed_done = self.done(done)
|
||||
# Create a new transition dict with the processed done flag
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.DONE] = processed_done
|
||||
return new_transition
|
||||
|
||||
|
||||
class TruncatedProcessor:
|
||||
@@ -857,7 +854,7 @@ class TruncatedProcessor:
|
||||
|
||||
Subclasses should override the `truncated` method to implement custom truncated flag processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed truncated flag
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -870,7 +867,7 @@ class TruncatedProcessor:
|
||||
return truncated or some_condition > self.threshold
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||||
manipulation, focusing only on the specific truncated flag processing logic.
|
||||
"""
|
||||
|
||||
@@ -886,17 +883,12 @@ class TruncatedProcessor:
|
||||
return truncated
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
truncated = transition[TransitionIndex.TRUNCATED]
|
||||
truncated = self.truncated(truncated)
|
||||
transition = (
|
||||
transition[TransitionIndex.OBSERVATION],
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
truncated,
|
||||
*transition[TransitionIndex.INFO :],
|
||||
)
|
||||
return transition
|
||||
truncated = transition.get(TransitionKey.TRUNCATED)
|
||||
processed_truncated = self.truncated(truncated)
|
||||
# Create a new transition dict with the processed truncated flag
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.TRUNCATED] = processed_truncated
|
||||
return new_transition
|
||||
|
||||
|
||||
class InfoProcessor:
|
||||
@@ -904,7 +896,7 @@ class InfoProcessor:
|
||||
|
||||
Subclasses should override the `info` method to implement custom info processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed info
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -922,7 +914,7 @@ class InfoProcessor:
|
||||
self.step_count = 0
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||||
manipulation, focusing only on the specific info dictionary processing logic.
|
||||
"""
|
||||
|
||||
@@ -938,18 +930,12 @@ class InfoProcessor:
|
||||
return info
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
info = transition[TransitionIndex.INFO]
|
||||
info = self.info(info)
|
||||
transition = (
|
||||
transition[TransitionIndex.OBSERVATION],
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
info,
|
||||
*transition[TransitionIndex.COMPLEMENTARY_DATA :],
|
||||
)
|
||||
return transition
|
||||
info = transition.get(TransitionKey.INFO)
|
||||
processed_info = self.info(info)
|
||||
# Create a new transition dict with the processed info
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.INFO] = processed_info
|
||||
return new_transition
|
||||
|
||||
|
||||
class ComplementaryDataProcessor:
|
||||
@@ -957,7 +943,7 @@ class ComplementaryDataProcessor:
|
||||
|
||||
Subclasses should override the `complementary_data` method to implement custom complementary data processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed complementary data
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
@@ -972,18 +958,12 @@ class ComplementaryDataProcessor:
|
||||
return complementary_data
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
|
||||
complementary_data = self.complementary_data(complementary_data)
|
||||
transition = (
|
||||
transition[TransitionIndex.OBSERVATION],
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
complementary_data,
|
||||
)
|
||||
return transition
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
processed_complementary_data = self.complementary_data(complementary_data)
|
||||
# Create a new transition dict with the processed complementary data
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data
|
||||
return new_transition
|
||||
|
||||
|
||||
class IdentityProcessor:
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -29,7 +29,7 @@ class RenameProcessor:
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition[TransitionIndex.OBSERVATION]
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
return transition
|
||||
|
||||
@@ -39,15 +39,11 @@ class RenameProcessor:
|
||||
processed_obs[self.rename_map[key]] = value
|
||||
else:
|
||||
processed_obs[key] = value
|
||||
return (
|
||||
processed_obs,
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
|
||||
# Create a new transition with the renamed observation
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"rename_map": self.rename_map}
|
||||
|
||||
@@ -72,7 +72,7 @@ from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
|
||||
from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor
|
||||
from lerobot.utils.io_utils import write_video
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.utils import (
|
||||
@@ -160,7 +160,7 @@ def rollout(
|
||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processed_transition = obs_processor(transition)
|
||||
observation = processed_transition[TransitionIndex.OBSERVATION]
|
||||
observation = processed_transition[TransitionKey.OBSERVATION]
|
||||
if return_observations:
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
@@ -211,7 +211,7 @@ def rollout(
|
||||
if return_observations:
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processed_transition = obs_processor(transition)
|
||||
observation = processed_transition[TransitionIndex.OBSERVATION]
|
||||
observation = processed_transition[TransitionKey.OBSERVATION]
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
|
||||
|
||||
@@ -22,7 +22,7 @@ from gymnasium.utils.env_checker import check_env
|
||||
|
||||
import lerobot
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
|
||||
from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor
|
||||
from tests.utils import require_env
|
||||
|
||||
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
||||
@@ -53,7 +53,7 @@ def test_factory(env_name):
|
||||
obs_processor = RobotProcessor([VanillaObservationProcessor()])
|
||||
transition = (obs, None, None, None, None, None, None)
|
||||
processed_transition = obs_processor(transition)
|
||||
obs = processed_transition[TransitionIndex.OBSERVATION]
|
||||
obs = processed_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# test image keys are float32 in range [0,1]
|
||||
for key in obs:
|
||||
|
||||
@@ -39,7 +39,7 @@ from lerobot.policies.factory import (
|
||||
)
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor
|
||||
from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
@@ -188,7 +188,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
obs_processor = RobotProcessor([VanillaObservationProcessor()])
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processed_transition = obs_processor(transition)
|
||||
observation = processed_transition[TransitionIndex.OBSERVATION]
|
||||
observation = processed_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
|
||||
|
||||
@@ -2,7 +2,7 @@ import torch
|
||||
|
||||
from lerobot.processor.pipeline import (
|
||||
RobotProcessor,
|
||||
TransitionIndex,
|
||||
TransitionKey,
|
||||
_default_batch_to_transition,
|
||||
_default_transition_to_batch,
|
||||
)
|
||||
@@ -63,27 +63,27 @@ def test_batch_to_transition_observation_grouping():
|
||||
transition = _default_batch_to_transition(batch)
|
||||
|
||||
# Check observation is a dict with all observation.* keys
|
||||
assert isinstance(transition[TransitionIndex.OBSERVATION], dict)
|
||||
assert "observation.image.top" in transition[TransitionIndex.OBSERVATION]
|
||||
assert "observation.image.left" in transition[TransitionIndex.OBSERVATION]
|
||||
assert "observation.state" in transition[TransitionIndex.OBSERVATION]
|
||||
assert isinstance(transition[TransitionKey.OBSERVATION], dict)
|
||||
assert "observation.image.top" in transition[TransitionKey.OBSERVATION]
|
||||
assert "observation.image.left" in transition[TransitionKey.OBSERVATION]
|
||||
assert "observation.state" in transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check values are preserved
|
||||
assert torch.allclose(
|
||||
transition[TransitionIndex.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
|
||||
transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
|
||||
)
|
||||
assert torch.allclose(
|
||||
transition[TransitionIndex.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
|
||||
transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
|
||||
)
|
||||
assert transition[TransitionIndex.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
||||
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields
|
||||
assert transition[TransitionIndex.ACTION] == "action_data"
|
||||
assert transition[TransitionIndex.REWARD] == 1.5
|
||||
assert transition[TransitionIndex.DONE]
|
||||
assert not transition[TransitionIndex.TRUNCATED]
|
||||
assert transition[TransitionIndex.INFO] == {"episode": 42}
|
||||
assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert transition[TransitionKey.REWARD] == 1.5
|
||||
assert transition[TransitionKey.DONE]
|
||||
assert not transition[TransitionKey.TRUNCATED]
|
||||
assert transition[TransitionKey.INFO] == {"episode": 42}
|
||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
|
||||
def test_transition_to_batch_observation_flattening():
|
||||
@@ -94,15 +94,15 @@ def test_transition_to_batch_observation_flattening():
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
}
|
||||
|
||||
transition = (
|
||||
observation_dict, # observation
|
||||
"action_data", # action
|
||||
1.5, # reward
|
||||
True, # done
|
||||
False, # truncated
|
||||
{"episode": 42}, # info
|
||||
{}, # complementary_data
|
||||
)
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: observation_dict,
|
||||
TransitionKey.ACTION: "action_data",
|
||||
TransitionKey.REWARD: 1.5,
|
||||
TransitionKey.DONE: True,
|
||||
TransitionKey.TRUNCATED: False,
|
||||
TransitionKey.INFO: {"episode": 42},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
|
||||
batch = _default_transition_to_batch(transition)
|
||||
|
||||
@@ -137,14 +137,14 @@ def test_no_observation_keys():
|
||||
transition = _default_batch_to_transition(batch)
|
||||
|
||||
# Observation should be None when no observation.* keys
|
||||
assert transition[TransitionIndex.OBSERVATION] is None
|
||||
assert transition[TransitionKey.OBSERVATION] is None
|
||||
|
||||
# Check other fields
|
||||
assert transition[TransitionIndex.ACTION] == "action_data"
|
||||
assert transition[TransitionIndex.REWARD] == 2.0
|
||||
assert not transition[TransitionIndex.DONE]
|
||||
assert transition[TransitionIndex.TRUNCATED]
|
||||
assert transition[TransitionIndex.INFO] == {"test": "no_obs"}
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert transition[TransitionKey.REWARD] == 2.0
|
||||
assert not transition[TransitionKey.DONE]
|
||||
assert transition[TransitionKey.TRUNCATED]
|
||||
assert transition[TransitionKey.INFO] == {"test": "no_obs"}
|
||||
|
||||
# Round trip should work
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
@@ -162,15 +162,15 @@ def test_minimal_batch():
|
||||
transition = _default_batch_to_transition(batch)
|
||||
|
||||
# Check observation
|
||||
assert transition[TransitionIndex.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||
assert transition[TransitionIndex.ACTION] == "minimal_action"
|
||||
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||
assert transition[TransitionKey.ACTION] == "minimal_action"
|
||||
|
||||
# Check defaults
|
||||
assert transition[TransitionIndex.REWARD] == 0.0
|
||||
assert not transition[TransitionIndex.DONE]
|
||||
assert not transition[TransitionIndex.TRUNCATED]
|
||||
assert transition[TransitionIndex.INFO] == {}
|
||||
assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
|
||||
assert transition[TransitionKey.REWARD] == 0.0
|
||||
assert not transition[TransitionKey.DONE]
|
||||
assert not transition[TransitionKey.TRUNCATED]
|
||||
assert transition[TransitionKey.INFO] == {}
|
||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
# Round trip
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
@@ -189,13 +189,13 @@ def test_empty_batch():
|
||||
transition = _default_batch_to_transition(batch)
|
||||
|
||||
# All fields should have defaults
|
||||
assert transition[TransitionIndex.OBSERVATION] is None
|
||||
assert transition[TransitionIndex.ACTION] is None
|
||||
assert transition[TransitionIndex.REWARD] == 0.0
|
||||
assert not transition[TransitionIndex.DONE]
|
||||
assert not transition[TransitionIndex.TRUNCATED]
|
||||
assert transition[TransitionIndex.INFO] == {}
|
||||
assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
|
||||
assert transition[TransitionKey.OBSERVATION] is None
|
||||
assert transition[TransitionKey.ACTION] is None
|
||||
assert transition[TransitionKey.REWARD] == 0.0
|
||||
assert not transition[TransitionKey.DONE]
|
||||
assert not transition[TransitionKey.TRUNCATED]
|
||||
assert transition[TransitionKey.INFO] == {}
|
||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
# Round trip
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
@@ -256,33 +256,27 @@ def test_custom_converter():
|
||||
# Custom converter that modifies the reward
|
||||
tr = _default_batch_to_transition(batch)
|
||||
# Double the reward
|
||||
reward = tr[TransitionIndex.REWARD] * 2 if tr[TransitionIndex.REWARD] is not None else 0.0
|
||||
return (
|
||||
tr[TransitionIndex.OBSERVATION],
|
||||
tr[TransitionIndex.ACTION],
|
||||
reward,
|
||||
tr[TransitionIndex.DONE],
|
||||
tr[TransitionIndex.TRUNCATED],
|
||||
tr[TransitionIndex.INFO],
|
||||
tr[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
reward = tr.get(TransitionKey.REWARD, 0.0)
|
||||
new_tr = tr.copy()
|
||||
new_tr[TransitionKey.REWARD] = reward * 2 if reward is not None else 0.0
|
||||
return new_tr
|
||||
|
||||
def to_batch(tr):
|
||||
# Custom converter that adds a custom field
|
||||
batch = _default_transition_to_batch(tr)
|
||||
batch["custom_field"] = "custom_value"
|
||||
return batch
|
||||
|
||||
proc = RobotProcessor([], to_transition=to_tr, to_output=to_batch)
|
||||
batch = _dummy_batch()
|
||||
out = proc(batch)
|
||||
processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)
|
||||
|
||||
# Check that custom modifications were applied
|
||||
assert out["next.reward"] == batch["next.reward"] * 2
|
||||
assert out["custom_field"] == "custom_value"
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 4),
|
||||
"action": torch.randn(1, 2),
|
||||
"next.reward": 1.0,
|
||||
"next.done": False,
|
||||
}
|
||||
|
||||
# Check that observation.* keys are still preserved
|
||||
original_obs_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
output_obs_keys = {k: v for k, v in out.items() if k.startswith("observation.")}
|
||||
result = processor(batch)
|
||||
|
||||
assert set(original_obs_keys.keys()) == set(output_obs_keys.keys())
|
||||
# Check the reward was doubled by our custom converter
|
||||
assert result["next.reward"] == 2.0
|
||||
assert torch.allclose(result["observation.state"], batch["observation.state"])
|
||||
assert torch.allclose(result["action"], batch["action"])
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from unittest.mock import Mock
|
||||
|
||||
import numpy as np
|
||||
@@ -10,7 +25,22 @@ from lerobot.processor.normalize_processor import (
|
||||
UnnormalizerProcessor,
|
||||
_convert_stats_to_tensors,
|
||||
)
|
||||
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
||||
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
|
||||
):
|
||||
"""Helper to create an EnvTransition dictionary."""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info,
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||||
}
|
||||
|
||||
|
||||
def test_numpy_conversion():
|
||||
@@ -120,10 +150,10 @@ def test_mean_std_normalization(observation_normalizer):
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
normalized_transition = observation_normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check mean/std normalization
|
||||
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
|
||||
@@ -134,10 +164,10 @@ def test_min_max_normalization(observation_normalizer):
|
||||
observation = {
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
normalized_transition = observation_normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check min/max normalization to [-1, 1]
|
||||
# For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0
|
||||
@@ -157,10 +187,10 @@ def test_selective_normalization(observation_stats):
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
normalized_transition = normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Only image should be normalized
|
||||
assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2)
|
||||
@@ -176,10 +206,10 @@ def test_device_compatibility(observation_stats):
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
normalized_transition = normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
assert normalized_obs["observation.image"].device.type == "cuda"
|
||||
|
||||
@@ -220,10 +250,10 @@ def test_state_dict_save_load(observation_normalizer):
|
||||
|
||||
# Test that it works the same
|
||||
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result1 = observation_normalizer(transition)[0]
|
||||
result2 = new_normalizer(transition)[0]
|
||||
result1 = observation_normalizer(transition)[TransitionKey.OBSERVATION]
|
||||
result2 = new_normalizer(transition)[TransitionKey.OBSERVATION]
|
||||
|
||||
assert torch.allclose(result1["observation.image"], result2["observation.image"])
|
||||
|
||||
@@ -271,10 +301,10 @@ def test_mean_std_unnormalization(action_stats_mean_std):
|
||||
)
|
||||
|
||||
normalized_action = torch.tensor([1.0, -0.5, 2.0])
|
||||
transition = (None, normalized_action, None, None, None, None, None)
|
||||
transition = create_transition(action=normalized_action)
|
||||
|
||||
unnormalized_transition = unnormalizer(transition)
|
||||
unnormalized_action = unnormalized_transition[TransitionIndex.ACTION]
|
||||
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
|
||||
|
||||
# action * std + mean
|
||||
expected = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0, 2.0 * 0.5 + 0.0])
|
||||
@@ -290,10 +320,10 @@ def test_min_max_unnormalization(action_stats_min_max):
|
||||
|
||||
# Actions in [-1, 1]
|
||||
normalized_action = torch.tensor([0.0, -1.0, 1.0])
|
||||
transition = (None, normalized_action, None, None, None, None, None)
|
||||
transition = create_transition(action=normalized_action)
|
||||
|
||||
unnormalized_transition = unnormalizer(transition)
|
||||
unnormalized_action = unnormalized_transition[TransitionIndex.ACTION]
|
||||
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
|
||||
|
||||
# Map from [-1, 1] to [min, max]
|
||||
# (action + 1) / 2 * (max - min) + min
|
||||
@@ -315,10 +345,10 @@ def test_numpy_action_input(action_stats_mean_std):
|
||||
)
|
||||
|
||||
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
|
||||
transition = (None, normalized_action, None, None, None, None, None)
|
||||
transition = create_transition(action=normalized_action)
|
||||
|
||||
unnormalized_transition = unnormalizer(transition)
|
||||
unnormalized_action = unnormalized_transition[TransitionIndex.ACTION]
|
||||
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
|
||||
|
||||
assert isinstance(unnormalized_action, torch.Tensor)
|
||||
expected = torch.tensor([1.0, -1.0, 1.0])
|
||||
@@ -332,7 +362,7 @@ def test_none_action(action_stats_mean_std):
|
||||
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
||||
)
|
||||
|
||||
transition = (None, None, None, None, None, None, None)
|
||||
transition = create_transition()
|
||||
result = unnormalizer(transition)
|
||||
|
||||
# Should return transition unchanged
|
||||
@@ -396,23 +426,31 @@ def test_combined_normalization(normalizer_processor):
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = (observation, action, 1.0, False, False, {}, {})
|
||||
transition = create_transition(
|
||||
observation=observation,
|
||||
action=action,
|
||||
reward=1.0,
|
||||
done=False,
|
||||
truncated=False,
|
||||
info={},
|
||||
complementary_data={},
|
||||
)
|
||||
|
||||
processed_transition = normalizer_processor(transition)
|
||||
|
||||
# Check normalized observations
|
||||
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
|
||||
processed_obs = processed_transition[TransitionKey.OBSERVATION]
|
||||
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
|
||||
assert torch.allclose(processed_obs["observation.image"], expected_image)
|
||||
|
||||
# Check normalized action
|
||||
processed_action = processed_transition[TransitionIndex.ACTION]
|
||||
processed_action = processed_transition[TransitionKey.ACTION]
|
||||
expected_action = torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0])
|
||||
assert torch.allclose(processed_action, expected_action)
|
||||
|
||||
# Check other fields remain unchanged
|
||||
assert processed_transition[TransitionIndex.REWARD] == 1.0
|
||||
assert not processed_transition[TransitionIndex.DONE]
|
||||
assert processed_transition[TransitionKey.REWARD] == 1.0
|
||||
assert not processed_transition[TransitionKey.DONE]
|
||||
|
||||
|
||||
def test_processor_from_lerobot_dataset(full_stats):
|
||||
@@ -466,13 +504,21 @@ def test_integration_with_robot_processor(normalizer_processor):
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = (observation, action, 1.0, False, False, {}, {})
|
||||
transition = create_transition(
|
||||
observation=observation,
|
||||
action=action,
|
||||
reward=1.0,
|
||||
done=False,
|
||||
truncated=False,
|
||||
info={},
|
||||
complementary_data={},
|
||||
)
|
||||
|
||||
processed_transition = robot_processor(transition)
|
||||
|
||||
# Verify the processing worked
|
||||
assert isinstance(processed_transition[TransitionIndex.OBSERVATION], dict)
|
||||
assert isinstance(processed_transition[TransitionIndex.ACTION], torch.Tensor)
|
||||
assert isinstance(processed_transition[TransitionKey.OBSERVATION], dict)
|
||||
assert isinstance(processed_transition[TransitionKey.ACTION], torch.Tensor)
|
||||
|
||||
|
||||
# Edge case tests
|
||||
@@ -482,7 +528,7 @@ def test_empty_observation():
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
transition = (None, None, None, None, None, None, None)
|
||||
transition = create_transition()
|
||||
result = normalizer(transition)
|
||||
|
||||
assert result == transition
|
||||
@@ -493,11 +539,13 @@ def test_empty_stats():
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
|
||||
observation = {"observation.image": torch.tensor([0.5])}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = normalizer(transition)
|
||||
# Should return observation unchanged since no stats are available
|
||||
assert torch.allclose(result[0]["observation.image"], observation["observation.image"])
|
||||
assert torch.allclose(
|
||||
result[TransitionKey.OBSERVATION]["observation.image"], observation["observation.image"]
|
||||
)
|
||||
|
||||
|
||||
def test_partial_stats():
|
||||
@@ -507,9 +555,9 @@ def test_partial_stats():
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
observation = {"observation.image": torch.tensor([0.7])}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
processed = normalizer(transition)[TransitionIndex.OBSERVATION]
|
||||
processed = normalizer(transition)[TransitionKey.OBSERVATION]
|
||||
assert torch.allclose(processed["observation.image"], observation["observation.image"])
|
||||
|
||||
|
||||
@@ -551,14 +599,25 @@ def test_serialization_roundtrip(full_stats):
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = (observation, action, 1.0, False, False, {}, {})
|
||||
transition = create_transition(
|
||||
observation=observation,
|
||||
action=action,
|
||||
reward=1.0,
|
||||
done=False,
|
||||
truncated=False,
|
||||
info={},
|
||||
complementary_data={},
|
||||
)
|
||||
|
||||
result1 = original_processor(transition)
|
||||
result2 = new_processor(transition)
|
||||
|
||||
# Compare results
|
||||
assert torch.allclose(result1[0]["observation.image"], result2[0]["observation.image"])
|
||||
assert torch.allclose(result1[1], result2[1])
|
||||
assert torch.allclose(
|
||||
result1[TransitionKey.OBSERVATION]["observation.image"],
|
||||
result2[TransitionKey.OBSERVATION]["observation.image"],
|
||||
)
|
||||
assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION])
|
||||
|
||||
# Verify features and norm_map are correctly reconstructed
|
||||
assert new_processor.features.keys() == original_processor.features.keys()
|
||||
|
||||
@@ -23,6 +23,22 @@ from lerobot.processor import (
|
||||
StateProcessor,
|
||||
VanillaObservationProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
|
||||
):
|
||||
"""Helper to create an EnvTransition dictionary."""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info,
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||||
}
|
||||
|
||||
|
||||
def test_process_single_image():
|
||||
@@ -33,10 +49,10 @@ def test_process_single_image():
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
|
||||
observation = {"pixels": image}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that the image was processed correctly
|
||||
assert "observation.image" in processed_obs
|
||||
@@ -60,10 +76,10 @@ def test_process_image_dict():
|
||||
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
|
||||
|
||||
observation = {"pixels": {"camera1": image1, "camera2": image2}}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that both images were processed
|
||||
assert "observation.images.camera1" in processed_obs
|
||||
@@ -82,10 +98,10 @@ def test_process_batched_image():
|
||||
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
|
||||
|
||||
observation = {"pixels": image}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that batch dimension is preserved
|
||||
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
|
||||
@@ -98,7 +114,7 @@ def test_invalid_image_format():
|
||||
# Test wrong channel order (channels first)
|
||||
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
|
||||
observation = {"pixels": image}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
with pytest.raises(ValueError, match="Expected channel-last images"):
|
||||
processor(transition)
|
||||
@@ -111,7 +127,7 @@ def test_invalid_image_dtype():
|
||||
# Test wrong dtype
|
||||
image = np.random.rand(64, 64, 3).astype(np.float32)
|
||||
observation = {"pixels": image}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
with pytest.raises(ValueError, match="Expected torch.uint8 images"):
|
||||
processor(transition)
|
||||
@@ -122,10 +138,10 @@ def test_no_pixels_in_observation():
|
||||
processor = ImageProcessor()
|
||||
|
||||
observation = {"other_data": np.array([1, 2, 3])}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Should preserve other data unchanged
|
||||
assert "other_data" in processed_obs
|
||||
@@ -136,7 +152,7 @@ def test_none_observation():
|
||||
"""Test processor with None observation."""
|
||||
processor = ImageProcessor()
|
||||
|
||||
transition = (None, None, None, None, None, None, None)
|
||||
transition = create_transition()
|
||||
result = processor(transition)
|
||||
|
||||
assert result == transition
|
||||
@@ -167,10 +183,10 @@ def test_process_environment_state():
|
||||
|
||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
observation = {"environment_state": env_state}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that environment_state was renamed and processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
@@ -188,10 +204,10 @@ def test_process_agent_pos():
|
||||
|
||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that agent_pos was renamed and processed
|
||||
assert "observation.state" in processed_obs
|
||||
@@ -211,10 +227,10 @@ def test_process_batched_states():
|
||||
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
|
||||
|
||||
observation = {"environment_state": env_state, "agent_pos": agent_pos}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that batch dimensions are preserved
|
||||
assert processed_obs["observation.environment_state"].shape == (2, 2)
|
||||
@@ -229,10 +245,10 @@ def test_process_both_states():
|
||||
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
|
||||
|
||||
observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that both states were processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
@@ -251,10 +267,10 @@ def test_no_states_in_observation():
|
||||
processor = StateProcessor()
|
||||
|
||||
observation = {"other_data": np.array([1, 2, 3])}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Should preserve data unchanged
|
||||
np.testing.assert_array_equal(processed_obs, observation)
|
||||
@@ -275,10 +291,10 @@ def test_complete_observation_processing():
|
||||
"agent_pos": agent_pos,
|
||||
"other_data": "preserve_me",
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that image was processed
|
||||
assert "observation.image" in processed_obs
|
||||
@@ -303,10 +319,10 @@ def test_image_only_processing():
|
||||
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
observation = {"pixels": image}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert "observation.image" in processed_obs
|
||||
assert len(processed_obs) == 1
|
||||
@@ -318,10 +334,10 @@ def test_state_only_processing():
|
||||
|
||||
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert "observation.state" in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
@@ -332,10 +348,10 @@ def test_empty_observation():
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
observation = {}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[0]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert processed_obs == {}
|
||||
|
||||
@@ -369,8 +385,8 @@ def test_equivalent_to_original_function():
|
||||
original_result = preprocess_observation(observation)
|
||||
|
||||
# Process with new processor
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processor_result = processor(transition)[0]
|
||||
transition = create_transition(observation=observation)
|
||||
processor_result = processor(transition)[TransitionKey.OBSERVATION]
|
||||
|
||||
# Compare results
|
||||
assert set(original_result.keys()) == set(processor_result.keys())
|
||||
@@ -396,8 +412,8 @@ def test_equivalent_with_image_dict():
|
||||
original_result = preprocess_observation(observation)
|
||||
|
||||
# Process with new processor
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processor_result = processor(transition)[0]
|
||||
transition = create_transition(observation=observation)
|
||||
processor_result = processor(transition)[TransitionKey.OBSERVATION]
|
||||
|
||||
# Compare results
|
||||
assert set(original_result.keys()) == set(processor_result.keys())
|
||||
|
||||
@@ -18,7 +18,7 @@ import json
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -26,6 +26,22 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None
|
||||
):
|
||||
"""Helper to create an EnvTransition dictionary."""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info if info is not None else {},
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -45,14 +61,16 @@ class MockStep:
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Add a counter to the complementary_data."""
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
|
||||
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
comp_data = {} if comp_data is None else dict(comp_data) # Make a copy
|
||||
|
||||
comp_data[f"{self.name}_counter"] = self.counter
|
||||
self.counter += 1
|
||||
|
||||
return (obs, action, reward, done, truncated, info, comp_data)
|
||||
# Create a new transition with updated complementary_data
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
# Return all JSON-serializable attributes that should be persisted
|
||||
@@ -79,12 +97,14 @@ class MockStepWithoutOptionalMethods:
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Multiply reward by multiplier."""
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
reward = transition.get(TransitionKey.REWARD)
|
||||
|
||||
if reward is not None:
|
||||
reward = reward * self.multiplier
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.REWARD] = reward * self.multiplier
|
||||
return new_transition
|
||||
|
||||
return (obs, action, reward, done, truncated, info, comp_data)
|
||||
return transition
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -105,7 +125,7 @@ class MockStepWithTensorState:
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Update running statistics."""
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
reward = transition.get(TransitionKey.REWARD)
|
||||
|
||||
if reward is not None:
|
||||
# Update running mean
|
||||
@@ -143,7 +163,7 @@ def test_empty_pipeline():
|
||||
"""Test pipeline with no steps."""
|
||||
pipeline = RobotProcessor()
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
transition = create_transition()
|
||||
result = pipeline(transition)
|
||||
|
||||
assert result == transition
|
||||
@@ -155,15 +175,15 @@ def test_single_step_pipeline():
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
transition = create_transition()
|
||||
result = pipeline(transition)
|
||||
|
||||
assert len(pipeline) == 1
|
||||
assert result[6]["test_step_counter"] == 0 # complementary_data
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0
|
||||
|
||||
# Call again to test counter increment
|
||||
result = pipeline(transition)
|
||||
assert result[6]["test_step_counter"] == 1
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 1
|
||||
|
||||
|
||||
def test_multiple_steps_pipeline():
|
||||
@@ -172,46 +192,46 @@ def test_multiple_steps_pipeline():
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotProcessor([step1, step2])
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
transition = create_transition()
|
||||
result = pipeline(transition)
|
||||
|
||||
assert len(pipeline) == 2
|
||||
assert result[6]["step1_counter"] == 0
|
||||
assert result[6]["step2_counter"] == 0
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["step1_counter"] == 0
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["step2_counter"] == 0
|
||||
|
||||
|
||||
def test_invalid_transition_format():
|
||||
"""Test pipeline with invalid transition format."""
|
||||
pipeline = RobotProcessor([MockStep()])
|
||||
|
||||
# Test with wrong number of elements
|
||||
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
|
||||
pipeline((None, None, 0.0)) # Only 3 elements
|
||||
# Test with wrong type (tuple instead of dict)
|
||||
with pytest.raises(ValueError, match="EnvTransition must be a dictionary"):
|
||||
pipeline((None, None, 0.0, False, False, {}, {})) # Tuple instead of dict
|
||||
|
||||
# Test with wrong type
|
||||
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
|
||||
pipeline("not a tuple")
|
||||
# Test with wrong type (string)
|
||||
with pytest.raises(ValueError, match="EnvTransition must be a dictionary"):
|
||||
pipeline("not a dict")
|
||||
|
||||
|
||||
def test_step_through():
|
||||
"""Test step_through method with tuple input."""
|
||||
"""Test step_through method with dict input."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotProcessor([step1, step2])
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
transition = create_transition()
|
||||
|
||||
results = list(pipeline.step_through(transition))
|
||||
|
||||
assert len(results) == 3 # Original + 2 steps
|
||||
assert results[0] == transition # Original
|
||||
assert "step1_counter" in results[1][6] # After step1
|
||||
assert "step2_counter" in results[2][6] # After step2
|
||||
assert "step1_counter" in results[1][TransitionKey.COMPLEMENTARY_DATA] # After step1
|
||||
assert "step2_counter" in results[2][TransitionKey.COMPLEMENTARY_DATA] # After step2
|
||||
|
||||
# Ensure all results are tuples (same format as input)
|
||||
# Ensure all results are dicts (same format as input)
|
||||
for result in results:
|
||||
assert isinstance(result, tuple)
|
||||
assert len(result) == 7
|
||||
assert isinstance(result, dict)
|
||||
assert all(isinstance(k, TransitionKey) for k in result.keys())
|
||||
|
||||
|
||||
def test_step_through_with_dict():
|
||||
@@ -279,7 +299,7 @@ def test_hooks():
|
||||
pipeline.register_before_step_hook(before_hook)
|
||||
pipeline.register_after_step_hook(after_hook)
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
transition = create_transition()
|
||||
pipeline(transition)
|
||||
|
||||
assert before_calls == [0]
|
||||
@@ -292,15 +312,16 @@ def test_hook_modification():
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
def modify_reward_hook(idx: int, transition: EnvTransition):
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
return (obs, action, 42.0, done, truncated, info, comp_data)
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.REWARD] = 42.0
|
||||
return new_transition
|
||||
|
||||
pipeline.register_before_step_hook(modify_reward_hook)
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
transition = create_transition()
|
||||
result = pipeline(transition)
|
||||
|
||||
assert result[2] == 42.0 # reward modified by hook
|
||||
assert result[TransitionKey.REWARD] == 42.0 # reward modified by hook
|
||||
|
||||
|
||||
def test_reset():
|
||||
@@ -316,7 +337,7 @@ def test_reset():
|
||||
pipeline.register_reset_hook(reset_hook)
|
||||
|
||||
# Make some calls to increment counter
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
transition = create_transition()
|
||||
pipeline(transition)
|
||||
pipeline(transition)
|
||||
|
||||
@@ -335,7 +356,7 @@ def test_profile_steps():
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotProcessor([step1, step2])
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
transition = create_transition()
|
||||
|
||||
profile_results = pipeline.profile_steps(transition, num_runs=10)
|
||||
|
||||
@@ -397,10 +418,10 @@ def test_step_without_optional_methods():
|
||||
step = MockStepWithoutOptionalMethods(multiplier=3.0)
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
transition = (None, None, 2.0, False, False, {}, {})
|
||||
transition = create_transition(reward=2.0)
|
||||
result = pipeline(transition)
|
||||
|
||||
assert result[2] == 6.0 # 2.0 * 3.0
|
||||
assert result[TransitionKey.REWARD] == 6.0 # 2.0 * 3.0
|
||||
|
||||
# Reset should work even if step doesn't implement reset
|
||||
pipeline.reset()
|
||||
@@ -419,7 +440,7 @@ def test_mixed_json_and_tensor_state():
|
||||
|
||||
# Process some transitions with rewards
|
||||
for i in range(10):
|
||||
transition = (None, None, float(i), False, False, {}, {})
|
||||
transition = create_transition(reward=float(i))
|
||||
pipeline(transition)
|
||||
|
||||
# Check state
|
||||
@@ -466,7 +487,7 @@ class MockModuleStep(nn.Module):
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Process transition and update running mean."""
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
if obs is not None and isinstance(obs, torch.Tensor):
|
||||
# Process observation through linear layer
|
||||
@@ -509,7 +530,7 @@ def test_to_device_with_state_dict():
|
||||
|
||||
# Process some transitions to populate state
|
||||
for i in range(10):
|
||||
transition = (None, None, float(i), False, False, {}, {})
|
||||
transition = create_transition(reward=float(i))
|
||||
pipeline(transition)
|
||||
|
||||
# Check initial device (should be CPU)
|
||||
@@ -551,7 +572,7 @@ def test_to_device_with_module():
|
||||
|
||||
# Process some data
|
||||
obs = torch.randn(2, 5)
|
||||
transition = (obs, None, 1.0, False, False, {}, {})
|
||||
transition = create_transition(observation=obs, reward=1.0)
|
||||
pipeline(transition)
|
||||
|
||||
# Check initial device
|
||||
@@ -575,7 +596,7 @@ def test_to_device_with_module():
|
||||
|
||||
# Verify the module still works after transfer
|
||||
obs_cuda = torch.randn(2, 5, device="cuda:0")
|
||||
transition = (obs_cuda, None, 1.0, False, False, {}, {})
|
||||
transition = create_transition(observation=obs_cuda, reward=1.0)
|
||||
pipeline(transition) # Should not raise an error
|
||||
|
||||
|
||||
@@ -589,7 +610,7 @@ def test_to_device_mixed_steps():
|
||||
|
||||
# Process some data
|
||||
for i in range(5):
|
||||
transition = (torch.randn(2, 10), None, float(i), False, False, {}, {})
|
||||
transition = create_transition(observation=torch.randn(2, 10), reward=float(i))
|
||||
pipeline(transition)
|
||||
|
||||
# Check initial state
|
||||
@@ -630,7 +651,7 @@ def test_to_device_preserves_functionality():
|
||||
# Process initial data
|
||||
rewards = [1.0, 2.0, 3.0]
|
||||
for r in rewards:
|
||||
transition = (None, None, r, False, False, {}, {})
|
||||
transition = create_transition(reward=r)
|
||||
pipeline(transition)
|
||||
|
||||
# Check state before transfer
|
||||
@@ -645,7 +666,7 @@ def test_to_device_preserves_functionality():
|
||||
assert step.running_count == initial_count
|
||||
|
||||
# Process more data to ensure functionality
|
||||
transition = (None, None, 4.0, False, False, {}, {})
|
||||
transition = create_transition(reward=4.0)
|
||||
_ = pipeline(transition)
|
||||
|
||||
assert step.running_count == 4
|
||||
@@ -700,7 +721,8 @@ class MockNonModuleStepWithState:
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Process transition using tensor operations."""
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
obs = transition.get(TransitionKey.OBSERVATION)
|
||||
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
if obs is not None and isinstance(obs, torch.Tensor) and obs.numel() >= self.feature_dim:
|
||||
# Perform some tensor operations
|
||||
@@ -718,7 +740,12 @@ class MockNonModuleStepWithState:
|
||||
comp_data[f"{self.name}_mean_output"] = output.mean().item()
|
||||
comp_data[f"{self.name}_steps"] = self.step_count.item()
|
||||
|
||||
return (obs, action, reward, done, truncated, info, comp_data)
|
||||
# Return updated transition
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
|
||||
return new_transition
|
||||
|
||||
return transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
@@ -763,9 +790,9 @@ def test_to_device_non_module_class():
|
||||
# Process some data to populate state
|
||||
for i in range(3):
|
||||
obs = torch.randn(2, 5)
|
||||
transition = (obs, None, float(i), False, False, {}, {})
|
||||
transition = create_transition(observation=obs, reward=float(i))
|
||||
result = pipeline(transition)
|
||||
comp_data = result[6]
|
||||
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert f"{non_module_step.name}_steps" in comp_data
|
||||
|
||||
# Verify all tensors are on CPU initially
|
||||
@@ -811,9 +838,9 @@ def test_to_device_non_module_class():
|
||||
|
||||
# Test that step still works on GPU
|
||||
obs_gpu = torch.randn(2, 5, device="cuda")
|
||||
transition = (obs_gpu, None, 1.0, False, False, {}, {})
|
||||
transition = create_transition(observation=obs_gpu, reward=1.0)
|
||||
result = pipeline(transition)
|
||||
comp_data = result[6]
|
||||
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
# Verify processing worked
|
||||
assert comp_data[f"{non_module_step.name}_steps"] == 4
|
||||
@@ -835,7 +862,7 @@ def test_to_device_module_vs_non_module():
|
||||
|
||||
# Process some data
|
||||
obs = torch.randn(2, 5)
|
||||
transition = (obs, None, 1.0, False, False, {}, {})
|
||||
transition = create_transition(observation=obs, reward=1.0)
|
||||
_ = pipeline(transition)
|
||||
|
||||
# Check initial devices
|
||||
@@ -860,7 +887,7 @@ def test_to_device_module_vs_non_module():
|
||||
|
||||
# Process data on GPU
|
||||
obs_gpu = torch.randn(2, 5, device="cuda")
|
||||
transition = (obs_gpu, None, 2.0, False, False, {}, {})
|
||||
transition = create_transition(observation=obs_gpu, reward=2.0)
|
||||
_ = pipeline(transition)
|
||||
|
||||
# Verify both steps processed the data
|
||||
@@ -889,7 +916,8 @@ class MockStepWithNonSerializableParam:
|
||||
self.env = env # Non-serializable parameter (like gym.Env)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
reward = transition.get(TransitionKey.REWARD)
|
||||
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
# Use the env parameter if provided
|
||||
if self.env is not None:
|
||||
@@ -897,10 +925,14 @@ class MockStepWithNonSerializableParam:
|
||||
comp_data[f"{self.name}_env_info"] = str(self.env)
|
||||
|
||||
# Apply multiplier to reward
|
||||
new_transition = transition.copy()
|
||||
if reward is not None:
|
||||
reward = reward * self.multiplier
|
||||
new_transition[TransitionKey.REWARD] = reward * self.multiplier
|
||||
|
||||
return (obs, action, reward, done, truncated, info, comp_data)
|
||||
if comp_data:
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
# Note: env is intentionally NOT included here as it's not serializable
|
||||
@@ -928,13 +960,15 @@ class RegisteredMockStep:
|
||||
device: str = "cpu"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
comp_data = {} if comp_data is None else dict(comp_data)
|
||||
comp_data["registered_step_value"] = self.value
|
||||
comp_data["registered_step_device"] = self.device
|
||||
|
||||
return (obs, action, reward, done, truncated, info, comp_data)
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
@@ -993,18 +1027,18 @@ def test_from_pretrained_with_overrides():
|
||||
assert loaded_pipeline.name == "TestOverrides"
|
||||
|
||||
# Test the loaded steps
|
||||
transition = (None, None, 1.0, False, False, {}, {})
|
||||
transition = create_transition(reward=1.0)
|
||||
result = loaded_pipeline(transition)
|
||||
|
||||
# Check that overrides were applied
|
||||
comp_data = result[6]
|
||||
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert "env_step_env_info" in comp_data
|
||||
assert comp_data["env_step_env_info"] == "MockEnvironment(test_env)"
|
||||
assert comp_data["registered_step_value"] == 200
|
||||
assert comp_data["registered_step_device"] == "cuda"
|
||||
|
||||
# Check that multiplier override was applied
|
||||
assert result[2] == 3.0 # 1.0 * 3.0 (overridden multiplier)
|
||||
assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0 (overridden multiplier)
|
||||
|
||||
|
||||
def test_from_pretrained_with_partial_overrides():
|
||||
@@ -1024,13 +1058,13 @@ def test_from_pretrained_with_partial_overrides():
|
||||
# Both steps will get the override
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
|
||||
|
||||
transition = (None, None, 1.0, False, False, {}, {})
|
||||
transition = create_transition(reward=1.0)
|
||||
result = loaded_pipeline(transition)
|
||||
|
||||
# The reward should be affected by both steps, both getting the override
|
||||
# First step: 1.0 * 5.0 = 5.0 (overridden)
|
||||
# Second step: 5.0 * 5.0 = 25.0 (also overridden)
|
||||
assert result[2] == 25.0
|
||||
assert result[TransitionKey.REWARD] == 25.0
|
||||
|
||||
|
||||
def test_from_pretrained_invalid_override_key():
|
||||
@@ -1082,10 +1116,10 @@ def test_from_pretrained_registered_step_override():
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
|
||||
|
||||
# Test that overrides were applied
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
transition = create_transition()
|
||||
result = loaded_pipeline(transition)
|
||||
|
||||
comp_data = result[6]
|
||||
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert comp_data["registered_step_value"] == 999
|
||||
assert comp_data["registered_step_device"] == "cuda"
|
||||
|
||||
@@ -1110,13 +1144,13 @@ def test_from_pretrained_mixed_registered_and_unregistered():
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
|
||||
|
||||
# Test both steps
|
||||
transition = (None, None, 2.0, False, False, {}, {})
|
||||
transition = create_transition(reward=2.0)
|
||||
result = loaded_pipeline(transition)
|
||||
|
||||
comp_data = result[6]
|
||||
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert comp_data["unregistered_env_info"] == "MockEnvironment(mixed_test)"
|
||||
assert comp_data["registered_step_value"] == 777
|
||||
assert result[2] == 8.0 # 2.0 * 4.0
|
||||
assert result[TransitionKey.REWARD] == 8.0 # 2.0 * 4.0
|
||||
|
||||
|
||||
def test_from_pretrained_no_overrides():
|
||||
@@ -1133,10 +1167,10 @@ def test_from_pretrained_no_overrides():
|
||||
assert len(loaded_pipeline) == 1
|
||||
|
||||
# Test that the step works (env will be None)
|
||||
transition = (None, None, 1.0, False, False, {}, {})
|
||||
transition = create_transition(reward=1.0)
|
||||
result = loaded_pipeline(transition)
|
||||
|
||||
assert result[2] == 3.0 # 1.0 * 3.0
|
||||
assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0
|
||||
|
||||
|
||||
def test_from_pretrained_empty_overrides():
|
||||
@@ -1153,10 +1187,10 @@ def test_from_pretrained_empty_overrides():
|
||||
assert len(loaded_pipeline) == 1
|
||||
|
||||
# Test that the step works normally
|
||||
transition = (None, None, 1.0, False, False, {}, {})
|
||||
transition = create_transition(reward=1.0)
|
||||
result = loaded_pipeline(transition)
|
||||
|
||||
assert result[2] == 2.0
|
||||
assert result[TransitionKey.REWARD] == 2.0
|
||||
|
||||
|
||||
def test_from_pretrained_override_instantiation_error():
|
||||
@@ -1185,7 +1219,7 @@ def test_from_pretrained_with_state_and_overrides():
|
||||
|
||||
# Process some data to create state
|
||||
for i in range(10):
|
||||
transition = (None, None, float(i), False, False, {}, {})
|
||||
transition = create_transition(reward=float(i))
|
||||
pipeline(transition)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
||||
@@ -13,14 +13,28 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionIndex
|
||||
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
|
||||
|
||||
|
||||
def create_transition(
    observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
    """Build an EnvTransition dictionary for use in tests.

    Every argument is optional and defaults to ``None``; the returned
    dictionary always contains all seven ``TransitionKey`` entries, so
    fields that were not supplied are present with a ``None`` value.

    Args:
        observation: Value stored under ``TransitionKey.OBSERVATION``.
        action: Value stored under ``TransitionKey.ACTION``.
        reward: Value stored under ``TransitionKey.REWARD``.
        done: Value stored under ``TransitionKey.DONE``.
        truncated: Value stored under ``TransitionKey.TRUNCATED``.
        info: Value stored under ``TransitionKey.INFO``.
        complementary_data: Value stored under
            ``TransitionKey.COMPLEMENTARY_DATA``.

    Returns:
        A new dict mapping each ``TransitionKey`` member listed above to the
        corresponding argument.
    """
    # Pairs are listed in the same order as the parameters so the resulting
    # dict keeps that insertion order.
    entries = (
        (TransitionKey.OBSERVATION, observation),
        (TransitionKey.ACTION, action),
        (TransitionKey.REWARD, reward),
        (TransitionKey.DONE, done),
        (TransitionKey.TRUNCATED, truncated),
        (TransitionKey.INFO, info),
        (TransitionKey.COMPLEMENTARY_DATA, complementary_data),
    )
    return dict(entries)
|
||||
|
||||
|
||||
def test_basic_renaming():
|
||||
@@ -36,10 +50,10 @@ def test_basic_renaming():
|
||||
"old_key2": np.array([3.0, 4.0]),
|
||||
"unchanged_key": "keep_me",
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check renamed keys
|
||||
assert "new_key1" in processed_obs
|
||||
@@ -63,10 +77,10 @@ def test_empty_rename_map():
|
||||
"key1": torch.tensor([1.0]),
|
||||
"key2": "value2",
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# All keys should be unchanged
|
||||
assert processed_obs.keys() == observation.keys()
|
||||
@@ -78,7 +92,7 @@ def test_none_observation():
|
||||
"""Test processor with None observation."""
|
||||
processor = RenameProcessor(rename_map={"old": "new"})
|
||||
|
||||
transition = (None, None, None, None, None, None, None)
|
||||
transition = create_transition()
|
||||
result = processor(transition)
|
||||
|
||||
# Should return transition unchanged
|
||||
@@ -98,10 +112,10 @@ def test_overlapping_rename():
|
||||
"b": 2,
|
||||
"x": 3,
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that renaming happens correctly
|
||||
assert "a" not in processed_obs
|
||||
@@ -124,10 +138,10 @@ def test_partial_rename():
|
||||
"reward": 1.0,
|
||||
"info": {"episode": 1},
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check renamed keys
|
||||
assert "observation.proprio_state" in processed_obs
|
||||
@@ -178,10 +192,12 @@ def test_integration_with_robot_processor():
|
||||
"pixels": np.zeros((32, 32, 3), dtype=np.uint8),
|
||||
"other_data": "preserve_me",
|
||||
}
|
||||
transition = (observation, None, 0.5, False, False, {}, {})
|
||||
transition = create_transition(
|
||||
observation=observation, reward=0.5, done=False, truncated=False, info={}, complementary_data={}
|
||||
)
|
||||
|
||||
result = pipeline(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check renaming worked through pipeline
|
||||
assert "observation.state" in processed_obs
|
||||
@@ -191,8 +207,8 @@ def test_integration_with_robot_processor():
|
||||
assert processed_obs["other_data"] == "preserve_me"
|
||||
|
||||
# Check other transition elements unchanged
|
||||
assert result[TransitionIndex.REWARD] == 0.5
|
||||
assert result[TransitionIndex.DONE] is False
|
||||
assert result[TransitionKey.REWARD] == 0.5
|
||||
assert result[TransitionKey.DONE] is False
|
||||
|
||||
|
||||
def test_save_and_load_pretrained():
|
||||
@@ -229,10 +245,10 @@ def test_save_and_load_pretrained():
|
||||
|
||||
# Test functionality after loading
|
||||
observation = {"old_state": [1, 2, 3], "old_image": "image_data"}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = loaded_pipeline(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert "observation.state" in processed_obs
|
||||
assert "observation.image" in processed_obs
|
||||
@@ -306,17 +322,17 @@ def test_chained_rename_processors():
|
||||
"img": "image_data",
|
||||
"extra": "keep_me",
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
# Step through to see intermediate results
|
||||
results = list(pipeline.step_through(transition))
|
||||
|
||||
# After first processor
|
||||
assert "agent_position" in results[1][TransitionIndex.OBSERVATION]
|
||||
assert "camera_image" in results[1][TransitionIndex.OBSERVATION]
|
||||
assert "agent_position" in results[1][TransitionKey.OBSERVATION]
|
||||
assert "camera_image" in results[1][TransitionKey.OBSERVATION]
|
||||
|
||||
# After second processor
|
||||
final_obs = results[2][TransitionIndex.OBSERVATION]
|
||||
final_obs = results[2][TransitionKey.OBSERVATION]
|
||||
assert "observation.state" in final_obs
|
||||
assert "observation.image" in final_obs
|
||||
assert final_obs["extra"] == "keep_me"
|
||||
@@ -343,10 +359,10 @@ def test_nested_observation_rename():
|
||||
"observation.proprio": torch.randn(7),
|
||||
"observation.gripper": torch.tensor([0.0]), # Not renamed
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check renames
|
||||
assert "observation.camera.left_view" in processed_obs
|
||||
@@ -378,10 +394,10 @@ def test_value_types_preserved():
|
||||
"old_dict": {"nested": "value"},
|
||||
"old_list": [1, 2, 3],
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that values and types are preserved
|
||||
assert torch.equal(processed_obs["new_tensor"], tensor_value)
|
||||
|
||||
Reference in New Issue
Block a user