mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
chore (docs): add docstring for processor
This commit is contained in:
@@ -36,11 +36,11 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
Returns:
|
||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||
"""
|
||||
from lerobot.processor.observation_processor import ObservationProcessor
|
||||
from lerobot.processor.observation_processor import VanillaObservationProcessor
|
||||
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
||||
|
||||
# Create processor with observation processor
|
||||
processor = RobotProcessor([ObservationProcessor()])
|
||||
processor = RobotProcessor([VanillaObservationProcessor()])
|
||||
|
||||
# Create transition tuple and process
|
||||
transition = (observations, None, None, None, None, None, None)
|
||||
|
||||
@@ -16,19 +16,35 @@
|
||||
from .normalize_processor import NormalizationProcessor
|
||||
from .observation_processor import (
|
||||
ImageProcessor,
|
||||
ObservationProcessor,
|
||||
StateProcessor,
|
||||
VanillaObservationProcessor,
|
||||
)
|
||||
from .pipeline import (
|
||||
ActionProcessor,
|
||||
DoneProcessor,
|
||||
EnvTransition,
|
||||
InfoProcessor,
|
||||
ObservationProcessor,
|
||||
ProcessorStep,
|
||||
RewardProcessor,
|
||||
RobotProcessor,
|
||||
TruncatedProcessor,
|
||||
)
|
||||
from .pipeline import EnvTransition, ProcessorStep, RobotProcessor
|
||||
from .rename_processor import RenameProcessor
|
||||
|
||||
__all__ = [
|
||||
"RobotProcessor",
|
||||
"ProcessorStep",
|
||||
"ActionProcessor",
|
||||
"DoneProcessor",
|
||||
"EnvTransition",
|
||||
"ImageProcessor",
|
||||
"StateProcessor",
|
||||
"ObservationProcessor",
|
||||
"InfoProcessor",
|
||||
"NormalizationProcessor",
|
||||
"ObservationProcessor",
|
||||
"ProcessorStep",
|
||||
"RenameProcessor",
|
||||
"RewardProcessor",
|
||||
"RobotProcessor",
|
||||
"StateProcessor",
|
||||
"TruncatedProcessor",
|
||||
"VanillaObservationProcessor",
|
||||
]
|
||||
|
||||
@@ -181,7 +181,7 @@ class StateProcessor:
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="observation_processor")
|
||||
class ObservationProcessor:
|
||||
class VanillaObservationProcessor:
|
||||
"""Complete observation processor that combines image and state processing.
|
||||
|
||||
This processor replicates the functionality of the original preprocess_observation
|
||||
|
||||
@@ -467,7 +467,35 @@ class RobotProcessor(ModelHubMixin):
|
||||
|
||||
|
||||
class ObservationProcessor:
|
||||
"""Base class for processors that modify only the observation component of a transition.
|
||||
|
||||
Subclasses should override the `observation` method to implement custom observation processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed observation
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
class MyObservationScaler(ObservationProcessor):
|
||||
def __init__(self, scale_factor):
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def observation(self, observation):
|
||||
return observation * self.scale_factor
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
manipulation, focusing only on the specific observation processing logic.
|
||||
"""
|
||||
|
||||
def observation(self, observation):
|
||||
"""Process the observation component.
|
||||
|
||||
Args:
|
||||
observation: The observation to process
|
||||
|
||||
Returns:
|
||||
The processed observation
|
||||
"""
|
||||
return observation
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -478,7 +506,36 @@ class ObservationProcessor:
|
||||
|
||||
|
||||
class ActionProcessor:
|
||||
"""Base class for processors that modify only the action component of a transition.
|
||||
|
||||
Subclasses should override the `action` method to implement custom action processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed action
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
class ActionClipping(ActionProcessor):
|
||||
def __init__(self, min_val, max_val):
|
||||
self.min_val = min_val
|
||||
self.max_val = max_val
|
||||
|
||||
def action(self, action):
|
||||
return np.clip(action, self.min_val, self.max_val)
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
manipulation, focusing only on the specific action processing logic.
|
||||
"""
|
||||
|
||||
def action(self, action):
|
||||
"""Process the action component.
|
||||
|
||||
Args:
|
||||
action: The action to process
|
||||
|
||||
Returns:
|
||||
The processed action
|
||||
"""
|
||||
return action
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -489,7 +546,35 @@ class ActionProcessor:
|
||||
|
||||
|
||||
class RewardProcessor:
|
||||
"""Base class for processors that modify only the reward component of a transition.
|
||||
|
||||
Subclasses should override the `reward` method to implement custom reward processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed reward
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
class RewardScaler(RewardProcessor):
|
||||
def __init__(self, scale_factor):
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def reward(self, reward):
|
||||
return reward * self.scale_factor
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
manipulation, focusing only on the specific reward processing logic.
|
||||
"""
|
||||
|
||||
def reward(self, reward):
|
||||
"""Process the reward component.
|
||||
|
||||
Args:
|
||||
reward: The reward to process
|
||||
|
||||
Returns:
|
||||
The processed reward
|
||||
"""
|
||||
return reward
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -505,7 +590,40 @@ class RewardProcessor:
|
||||
|
||||
|
||||
class DoneProcessor:
|
||||
"""Base class for processors that modify only the done flag of a transition.
|
||||
|
||||
Subclasses should override the `done` method to implement custom done flag processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed done flag
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
class TimeoutDone(DoneProcessor):
|
||||
def __init__(self, max_steps):
|
||||
self.steps = 0
|
||||
self.max_steps = max_steps
|
||||
|
||||
def done(self, done):
|
||||
self.steps += 1
|
||||
return done or self.steps >= self.max_steps
|
||||
|
||||
def reset(self):
|
||||
self.steps = 0
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
manipulation, focusing only on the specific done flag processing logic.
|
||||
"""
|
||||
|
||||
def done(self, done):
|
||||
"""Process the done flag.
|
||||
|
||||
Args:
|
||||
done: The done flag to process
|
||||
|
||||
Returns:
|
||||
The processed done flag
|
||||
"""
|
||||
return done
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -522,7 +640,36 @@ class DoneProcessor:
|
||||
|
||||
|
||||
class TruncatedProcessor:
|
||||
"""Base class for processors that modify only the truncated flag of a transition.
|
||||
|
||||
Subclasses should override the `truncated` method to implement custom truncated flag processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed truncated flag
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
class EarlyTruncation(TruncatedProcessor):
|
||||
def __init__(self, threshold):
|
||||
self.threshold = threshold
|
||||
|
||||
def truncated(self, truncated):
|
||||
# Additional truncation condition
|
||||
return truncated or some_condition > self.threshold
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
manipulation, focusing only on the specific truncated flag processing logic.
|
||||
"""
|
||||
|
||||
def truncated(self, truncated):
|
||||
"""Process the truncated flag.
|
||||
|
||||
Args:
|
||||
truncated: The truncated flag to process
|
||||
|
||||
Returns:
|
||||
The processed truncated flag
|
||||
"""
|
||||
return truncated
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -540,7 +687,41 @@ class TruncatedProcessor:
|
||||
|
||||
|
||||
class InfoProcessor:
|
||||
"""Base class for processors that modify only the info dictionary of a transition.
|
||||
|
||||
Subclasses should override the `info` method to implement custom info processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed info
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
|
||||
Example:
|
||||
```python
|
||||
class InfoAugmenter(InfoProcessor):
|
||||
def __init__(self):
|
||||
self.step_count = 0
|
||||
|
||||
def info(self, info):
|
||||
info = info.copy() # Create a copy to avoid modifying the original
|
||||
info['steps'] = self.step_count
|
||||
self.step_count += 1
|
||||
return info
|
||||
|
||||
def reset(self):
|
||||
self.step_count = 0
|
||||
```
|
||||
|
||||
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||
manipulation, focusing only on the specific info dictionary processing logic.
|
||||
"""
|
||||
|
||||
def info(self, info):
|
||||
"""Process the info dictionary.
|
||||
|
||||
Args:
|
||||
info: The info dictionary to process
|
||||
|
||||
Returns:
|
||||
The processed info dictionary
|
||||
"""
|
||||
return info
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
|
||||
@@ -20,8 +20,8 @@ import torch
|
||||
|
||||
from lerobot.processor.observation_processor import (
|
||||
ImageProcessor,
|
||||
ObservationProcessor,
|
||||
StateProcessor,
|
||||
VanillaObservationProcessor,
|
||||
)
|
||||
|
||||
|
||||
@@ -262,7 +262,7 @@ def test_no_states_in_observation():
|
||||
|
||||
def test_complete_observation_processing():
|
||||
"""Test processing a complete observation with both images and states."""
|
||||
processor = ObservationProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Create mock data
|
||||
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
@@ -299,7 +299,7 @@ def test_complete_observation_processing():
|
||||
|
||||
def test_image_only_processing():
|
||||
"""Test processing observation with only images."""
|
||||
processor = ObservationProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
observation = {"pixels": image}
|
||||
@@ -314,7 +314,7 @@ def test_image_only_processing():
|
||||
|
||||
def test_state_only_processing():
|
||||
"""Test processing observation with only states."""
|
||||
processor = ObservationProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
@@ -329,7 +329,7 @@ def test_state_only_processing():
|
||||
|
||||
def test_empty_observation():
|
||||
"""Test processing empty observation."""
|
||||
processor = ObservationProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
observation = {}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
@@ -344,7 +344,7 @@ def test_custom_sub_processors():
|
||||
"""Test ObservationProcessor with custom sub-processors."""
|
||||
image_proc = ImageProcessor()
|
||||
state_proc = StateProcessor()
|
||||
processor = ObservationProcessor(image_processor=image_proc, state_processor=state_proc)
|
||||
processor = VanillaObservationProcessor(image_processor=image_proc, state_processor=state_proc)
|
||||
|
||||
# Should use the provided processors
|
||||
assert processor.image_processor is image_proc
|
||||
@@ -356,7 +356,7 @@ def test_equivalent_to_original_function():
|
||||
# Import the original function for comparison
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
|
||||
processor = ObservationProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Create test data similar to what the original function expects
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
@@ -383,7 +383,7 @@ def test_equivalent_with_image_dict():
|
||||
"""Test equivalence with dictionary of images."""
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
|
||||
processor = ObservationProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Create test data with multiple cameras
|
||||
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
|
||||
Reference in New Issue
Block a user