chore (docs): add docstring for processor

This commit is contained in:
Adil Zouitine
2025-07-04 11:07:18 +02:00
parent 453e0a995f
commit 9f33791b19
5 changed files with 214 additions and 17 deletions
+2 -2
View File
@@ -36,11 +36,11 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
"""
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.processor.observation_processor import VanillaObservationProcessor
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
# Create processor with observation processor
processor = RobotProcessor([ObservationProcessor()])
processor = RobotProcessor([VanillaObservationProcessor()])
# Create transition tuple and process
transition = (observations, None, None, None, None, None, None)
+22 -6
View File
@@ -16,19 +16,35 @@
from .normalize_processor import NormalizationProcessor
from .observation_processor import (
ImageProcessor,
ObservationProcessor,
StateProcessor,
VanillaObservationProcessor,
)
from .pipeline import (
ActionProcessor,
DoneProcessor,
EnvTransition,
InfoProcessor,
ObservationProcessor,
ProcessorStep,
RewardProcessor,
RobotProcessor,
TruncatedProcessor,
)
from .pipeline import EnvTransition, ProcessorStep, RobotProcessor
from .rename_processor import RenameProcessor
__all__ = [
"RobotProcessor",
"ProcessorStep",
"ActionProcessor",
"DoneProcessor",
"EnvTransition",
"ImageProcessor",
"StateProcessor",
"ObservationProcessor",
"InfoProcessor",
"NormalizationProcessor",
"ObservationProcessor",
"ProcessorStep",
"RenameProcessor",
"RewardProcessor",
"RobotProcessor",
"StateProcessor",
"TruncatedProcessor",
"VanillaObservationProcessor",
]
@@ -181,7 +181,7 @@ class StateProcessor:
@dataclass
@ProcessorStepRegistry.register(name="observation_processor")
class ObservationProcessor:
class VanillaObservationProcessor:
"""Complete observation processor that combines image and state processing.
This processor replicates the functionality of the original preprocess_observation
+181
View File
@@ -467,7 +467,35 @@ class RobotProcessor(ModelHubMixin):
class ObservationProcessor:
"""Base class for processors that modify only the observation component of a transition.
Subclasses should override the `observation` method to implement custom observation processing.
This class handles the boilerplate of extracting and reinserting the processed observation
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
class MyObservationScaler(ObservationProcessor):
def __init__(self, scale_factor):
self.scale_factor = scale_factor
def observation(self, observation):
return observation * self.scale_factor
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
manipulation, focusing only on the specific observation processing logic.
"""
def observation(self, observation):
"""Process the observation component.
Args:
observation: The observation to process
Returns:
The processed observation
"""
return observation
def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -478,7 +506,36 @@ class ObservationProcessor:
class ActionProcessor:
"""Base class for processors that modify only the action component of a transition.
Subclasses should override the `action` method to implement custom action processing.
This class handles the boilerplate of extracting and reinserting the processed action
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
class ActionClipping(ActionProcessor):
def __init__(self, min_val, max_val):
self.min_val = min_val
self.max_val = max_val
def action(self, action):
return np.clip(action, self.min_val, self.max_val)
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
manipulation, focusing only on the specific action processing logic.
"""
def action(self, action):
"""Process the action component.
Args:
action: The action to process
Returns:
The processed action
"""
return action
def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -489,7 +546,35 @@ class ActionProcessor:
class RewardProcessor:
"""Base class for processors that modify only the reward component of a transition.
Subclasses should override the `reward` method to implement custom reward processing.
This class handles the boilerplate of extracting and reinserting the processed reward
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
class RewardScaler(RewardProcessor):
def __init__(self, scale_factor):
self.scale_factor = scale_factor
def reward(self, reward):
return reward * self.scale_factor
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
manipulation, focusing only on the specific reward processing logic.
"""
def reward(self, reward):
"""Process the reward component.
Args:
reward: The reward to process
Returns:
The processed reward
"""
return reward
def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -505,7 +590,40 @@ class RewardProcessor:
class DoneProcessor:
"""Base class for processors that modify only the done flag of a transition.
Subclasses should override the `done` method to implement custom done flag processing.
This class handles the boilerplate of extracting and reinserting the processed done flag
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
class TimeoutDone(DoneProcessor):
def __init__(self, max_steps):
self.steps = 0
self.max_steps = max_steps
def done(self, done):
self.steps += 1
return done or self.steps >= self.max_steps
def reset(self):
self.steps = 0
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
manipulation, focusing only on the specific done flag processing logic.
"""
def done(self, done):
"""Process the done flag.
Args:
done: The done flag to process
Returns:
The processed done flag
"""
return done
def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -522,7 +640,36 @@ class DoneProcessor:
class TruncatedProcessor:
"""Base class for processors that modify only the truncated flag of a transition.
Subclasses should override the `truncated` method to implement custom truncated flag processing.
This class handles the boilerplate of extracting and reinserting the processed truncated flag
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
class EarlyTruncation(TruncatedProcessor):
def __init__(self, threshold):
self.threshold = threshold
def truncated(self, truncated):
# Additional truncation condition
return truncated or some_condition > self.threshold
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
manipulation, focusing only on the specific truncated flag processing logic.
"""
def truncated(self, truncated):
"""Process the truncated flag.
Args:
truncated: The truncated flag to process
Returns:
The processed truncated flag
"""
return truncated
def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -540,7 +687,41 @@ class TruncatedProcessor:
class InfoProcessor:
"""Base class for processors that modify only the info dictionary of a transition.
Subclasses should override the `info` method to implement custom info processing.
This class handles the boilerplate of extracting and reinserting the processed info
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
class InfoAugmenter(InfoProcessor):
def __init__(self):
self.step_count = 0
def info(self, info):
info = info.copy() # Create a copy to avoid modifying the original
info['steps'] = self.step_count
self.step_count += 1
return info
def reset(self):
self.step_count = 0
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
manipulation, focusing only on the specific info dictionary processing logic.
"""
def info(self, info):
"""Process the info dictionary.
Args:
info: The info dictionary to process
Returns:
The processed info dictionary
"""
return info
def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -20,8 +20,8 @@ import torch
from lerobot.processor.observation_processor import (
ImageProcessor,
ObservationProcessor,
StateProcessor,
VanillaObservationProcessor,
)
@@ -262,7 +262,7 @@ def test_no_states_in_observation():
def test_complete_observation_processing():
"""Test processing a complete observation with both images and states."""
processor = ObservationProcessor()
processor = VanillaObservationProcessor()
# Create mock data
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
@@ -299,7 +299,7 @@ def test_complete_observation_processing():
def test_image_only_processing():
"""Test processing observation with only images."""
processor = ObservationProcessor()
processor = VanillaObservationProcessor()
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
@@ -314,7 +314,7 @@ def test_image_only_processing():
def test_state_only_processing():
"""Test processing observation with only states."""
processor = ObservationProcessor()
processor = VanillaObservationProcessor()
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
@@ -329,7 +329,7 @@ def test_state_only_processing():
def test_empty_observation():
"""Test processing empty observation."""
processor = ObservationProcessor()
processor = VanillaObservationProcessor()
observation = {}
transition = (observation, None, None, None, None, None, None)
@@ -344,7 +344,7 @@ def test_custom_sub_processors():
"""Test ObservationProcessor with custom sub-processors."""
image_proc = ImageProcessor()
state_proc = StateProcessor()
processor = ObservationProcessor(image_processor=image_proc, state_processor=state_proc)
processor = VanillaObservationProcessor(image_processor=image_proc, state_processor=state_proc)
# Should use the provided processors
assert processor.image_processor is image_proc
@@ -356,7 +356,7 @@ def test_equivalent_to_original_function():
# Import the original function for comparison
from lerobot.envs.utils import preprocess_observation
processor = ObservationProcessor()
processor = VanillaObservationProcessor()
# Create test data similar to what the original function expects
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
@@ -383,7 +383,7 @@ def test_equivalent_with_image_dict():
"""Test equivalence with dictionary of images."""
from lerobot.envs.utils import preprocess_observation
processor = ObservationProcessor()
processor = VanillaObservationProcessor()
# Create test data with multiple cameras
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)