chore (docs): add docstring for processor

This commit is contained in:
Adil Zouitine
2025-07-04 11:07:18 +02:00
parent 453e0a995f
commit 9f33791b19
5 changed files with 214 additions and 17 deletions
+2 -2
View File
@@ -36,11 +36,11 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
Returns: Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
""" """
from lerobot.processor.observation_processor import ObservationProcessor from lerobot.processor.observation_processor import VanillaObservationProcessor
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
# Create processor with observation processor # Create processor with observation processor
processor = RobotProcessor([ObservationProcessor()]) processor = RobotProcessor([VanillaObservationProcessor()])
# Create transition tuple and process # Create transition tuple and process
transition = (observations, None, None, None, None, None, None) transition = (observations, None, None, None, None, None, None)
+22 -6
View File
@@ -16,19 +16,35 @@
from .normalize_processor import NormalizationProcessor from .normalize_processor import NormalizationProcessor
from .observation_processor import ( from .observation_processor import (
ImageProcessor, ImageProcessor,
ObservationProcessor,
StateProcessor, StateProcessor,
VanillaObservationProcessor,
)
from .pipeline import (
ActionProcessor,
DoneProcessor,
EnvTransition,
InfoProcessor,
ObservationProcessor,
ProcessorStep,
RewardProcessor,
RobotProcessor,
TruncatedProcessor,
) )
from .pipeline import EnvTransition, ProcessorStep, RobotProcessor
from .rename_processor import RenameProcessor from .rename_processor import RenameProcessor
__all__ = [ __all__ = [
"RobotProcessor", "ActionProcessor",
"ProcessorStep", "DoneProcessor",
"EnvTransition", "EnvTransition",
"ImageProcessor", "ImageProcessor",
"StateProcessor", "InfoProcessor",
"ObservationProcessor",
"NormalizationProcessor", "NormalizationProcessor",
"ObservationProcessor",
"ProcessorStep",
"RenameProcessor", "RenameProcessor",
"RewardProcessor",
"RobotProcessor",
"StateProcessor",
"TruncatedProcessor",
"VanillaObservationProcessor",
] ]
@@ -181,7 +181,7 @@ class StateProcessor:
@dataclass @dataclass
@ProcessorStepRegistry.register(name="observation_processor") @ProcessorStepRegistry.register(name="observation_processor")
class ObservationProcessor: class VanillaObservationProcessor:
"""Complete observation processor that combines image and state processing. """Complete observation processor that combines image and state processing.
This processor replicates the functionality of the original preprocess_observation This processor replicates the functionality of the original preprocess_observation
+181
View File
@@ -467,7 +467,35 @@ class RobotProcessor(ModelHubMixin):
class ObservationProcessor: class ObservationProcessor:
"""Base class for processors that modify only the observation component of a transition.
Subclasses should override the `observation` method to implement custom observation processing.
This class handles the boilerplate of extracting and reinserting the processed observation
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
class MyObservationScaler(ObservationProcessor):
def __init__(self, scale_factor):
self.scale_factor = scale_factor
def observation(self, observation):
return observation * self.scale_factor
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
manipulation, focusing only on the specific observation processing logic.
"""
def observation(self, observation): def observation(self, observation):
"""Process the observation component.
Args:
observation: The observation to process
Returns:
The processed observation
"""
return observation return observation
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -478,7 +506,36 @@ class ObservationProcessor:
class ActionProcessor: class ActionProcessor:
"""Base class for processors that modify only the action component of a transition.
Subclasses should override the `action` method to implement custom action processing.
This class handles the boilerplate of extracting and reinserting the processed action
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
class ActionClipping(ActionProcessor):
def __init__(self, min_val, max_val):
self.min_val = min_val
self.max_val = max_val
def action(self, action):
return np.clip(action, self.min_val, self.max_val)
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
manipulation, focusing only on the specific action processing logic.
"""
def action(self, action): def action(self, action):
"""Process the action component.
Args:
action: The action to process
Returns:
The processed action
"""
return action return action
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -489,7 +546,35 @@ class ActionProcessor:
class RewardProcessor: class RewardProcessor:
"""Base class for processors that modify only the reward component of a transition.
Subclasses should override the `reward` method to implement custom reward processing.
This class handles the boilerplate of extracting and reinserting the processed reward
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
class RewardScaler(RewardProcessor):
def __init__(self, scale_factor):
self.scale_factor = scale_factor
def reward(self, reward):
return reward * self.scale_factor
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
manipulation, focusing only on the specific reward processing logic.
"""
def reward(self, reward): def reward(self, reward):
"""Process the reward component.
Args:
reward: The reward to process
Returns:
The processed reward
"""
return reward return reward
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -505,7 +590,40 @@ class RewardProcessor:
class DoneProcessor: class DoneProcessor:
"""Base class for processors that modify only the done flag of a transition.
Subclasses should override the `done` method to implement custom done flag processing.
This class handles the boilerplate of extracting and reinserting the processed done flag
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
class TimeoutDone(DoneProcessor):
def __init__(self, max_steps):
self.steps = 0
self.max_steps = max_steps
def done(self, done):
self.steps += 1
return done or self.steps >= self.max_steps
def reset(self):
self.steps = 0
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
manipulation, focusing only on the specific done flag processing logic.
"""
def done(self, done): def done(self, done):
"""Process the done flag.
Args:
done: The done flag to process
Returns:
The processed done flag
"""
return done return done
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -522,7 +640,36 @@ class DoneProcessor:
class TruncatedProcessor: class TruncatedProcessor:
"""Base class for processors that modify only the truncated flag of a transition.
Subclasses should override the `truncated` method to implement custom truncated flag processing.
This class handles the boilerplate of extracting and reinserting the processed truncated flag
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
class EarlyTruncation(TruncatedProcessor):
def __init__(self, threshold):
self.threshold = threshold
def truncated(self, truncated):
# Additional truncation condition
return truncated or some_condition > self.threshold
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
manipulation, focusing only on the specific truncated flag processing logic.
"""
def truncated(self, truncated): def truncated(self, truncated):
"""Process the truncated flag.
Args:
truncated: The truncated flag to process
Returns:
The processed truncated flag
"""
return truncated return truncated
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -540,7 +687,41 @@ class TruncatedProcessor:
class InfoProcessor: class InfoProcessor:
"""Base class for processors that modify only the info dictionary of a transition.
Subclasses should override the `info` method to implement custom info processing.
This class handles the boilerplate of extracting and reinserting the processed info
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
Example:
```python
class InfoAugmenter(InfoProcessor):
def __init__(self):
self.step_count = 0
def info(self, info):
info = info.copy() # Create a copy to avoid modifying the original
info['steps'] = self.step_count
self.step_count += 1
return info
def reset(self):
self.step_count = 0
```
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
manipulation, focusing only on the specific info dictionary processing logic.
"""
def info(self, info): def info(self, info):
"""Process the info dictionary.
Args:
info: The info dictionary to process
Returns:
The processed info dictionary
"""
return info return info
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -20,8 +20,8 @@ import torch
from lerobot.processor.observation_processor import ( from lerobot.processor.observation_processor import (
ImageProcessor, ImageProcessor,
ObservationProcessor,
StateProcessor, StateProcessor,
VanillaObservationProcessor,
) )
@@ -262,7 +262,7 @@ def test_no_states_in_observation():
def test_complete_observation_processing(): def test_complete_observation_processing():
"""Test processing a complete observation with both images and states.""" """Test processing a complete observation with both images and states."""
processor = ObservationProcessor() processor = VanillaObservationProcessor()
# Create mock data # Create mock data
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
@@ -299,7 +299,7 @@ def test_complete_observation_processing():
def test_image_only_processing(): def test_image_only_processing():
"""Test processing observation with only images.""" """Test processing observation with only images."""
processor = ObservationProcessor() processor = VanillaObservationProcessor()
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image} observation = {"pixels": image}
@@ -314,7 +314,7 @@ def test_image_only_processing():
def test_state_only_processing(): def test_state_only_processing():
"""Test processing observation with only states.""" """Test processing observation with only states."""
processor = ObservationProcessor() processor = VanillaObservationProcessor()
agent_pos = np.array([1.0, 2.0], dtype=np.float32) agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"agent_pos": agent_pos} observation = {"agent_pos": agent_pos}
@@ -329,7 +329,7 @@ def test_state_only_processing():
def test_empty_observation(): def test_empty_observation():
"""Test processing empty observation.""" """Test processing empty observation."""
processor = ObservationProcessor() processor = VanillaObservationProcessor()
observation = {} observation = {}
transition = (observation, None, None, None, None, None, None) transition = (observation, None, None, None, None, None, None)
@@ -344,7 +344,7 @@ def test_custom_sub_processors():
"""Test ObservationProcessor with custom sub-processors.""" """Test ObservationProcessor with custom sub-processors."""
image_proc = ImageProcessor() image_proc = ImageProcessor()
state_proc = StateProcessor() state_proc = StateProcessor()
processor = ObservationProcessor(image_processor=image_proc, state_processor=state_proc) processor = VanillaObservationProcessor(image_processor=image_proc, state_processor=state_proc)
# Should use the provided processors # Should use the provided processors
assert processor.image_processor is image_proc assert processor.image_processor is image_proc
@@ -356,7 +356,7 @@ def test_equivalent_to_original_function():
# Import the original function for comparison # Import the original function for comparison
from lerobot.envs.utils import preprocess_observation from lerobot.envs.utils import preprocess_observation
processor = ObservationProcessor() processor = VanillaObservationProcessor()
# Create test data similar to what the original function expects # Create test data similar to what the original function expects
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
@@ -383,7 +383,7 @@ def test_equivalent_with_image_dict():
"""Test equivalence with dictionary of images.""" """Test equivalence with dictionary of images."""
from lerobot.envs.utils import preprocess_observation from lerobot.envs.utils import preprocess_observation
processor = ObservationProcessor() processor = VanillaObservationProcessor()
# Create test data with multiple cameras # Create test data with multiple cameras
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)