mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
chore (docs): add docstring for processor
This commit is contained in:
@@ -36,11 +36,11 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||||
"""
|
"""
|
||||||
from lerobot.processor.observation_processor import ObservationProcessor
|
from lerobot.processor.observation_processor import VanillaObservationProcessor
|
||||||
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
||||||
|
|
||||||
# Create processor with observation processor
|
# Create processor with observation processor
|
||||||
processor = RobotProcessor([ObservationProcessor()])
|
processor = RobotProcessor([VanillaObservationProcessor()])
|
||||||
|
|
||||||
# Create transition tuple and process
|
# Create transition tuple and process
|
||||||
transition = (observations, None, None, None, None, None, None)
|
transition = (observations, None, None, None, None, None, None)
|
||||||
|
|||||||
@@ -16,19 +16,35 @@
|
|||||||
from .normalize_processor import NormalizationProcessor
|
from .normalize_processor import NormalizationProcessor
|
||||||
from .observation_processor import (
|
from .observation_processor import (
|
||||||
ImageProcessor,
|
ImageProcessor,
|
||||||
ObservationProcessor,
|
|
||||||
StateProcessor,
|
StateProcessor,
|
||||||
|
VanillaObservationProcessor,
|
||||||
|
)
|
||||||
|
from .pipeline import (
|
||||||
|
ActionProcessor,
|
||||||
|
DoneProcessor,
|
||||||
|
EnvTransition,
|
||||||
|
InfoProcessor,
|
||||||
|
ObservationProcessor,
|
||||||
|
ProcessorStep,
|
||||||
|
RewardProcessor,
|
||||||
|
RobotProcessor,
|
||||||
|
TruncatedProcessor,
|
||||||
)
|
)
|
||||||
from .pipeline import EnvTransition, ProcessorStep, RobotProcessor
|
|
||||||
from .rename_processor import RenameProcessor
|
from .rename_processor import RenameProcessor
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"RobotProcessor",
|
"ActionProcessor",
|
||||||
"ProcessorStep",
|
"DoneProcessor",
|
||||||
"EnvTransition",
|
"EnvTransition",
|
||||||
"ImageProcessor",
|
"ImageProcessor",
|
||||||
"StateProcessor",
|
"InfoProcessor",
|
||||||
"ObservationProcessor",
|
|
||||||
"NormalizationProcessor",
|
"NormalizationProcessor",
|
||||||
|
"ObservationProcessor",
|
||||||
|
"ProcessorStep",
|
||||||
"RenameProcessor",
|
"RenameProcessor",
|
||||||
|
"RewardProcessor",
|
||||||
|
"RobotProcessor",
|
||||||
|
"StateProcessor",
|
||||||
|
"TruncatedProcessor",
|
||||||
|
"VanillaObservationProcessor",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ class StateProcessor:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register(name="observation_processor")
|
@ProcessorStepRegistry.register(name="observation_processor")
|
||||||
class ObservationProcessor:
|
class VanillaObservationProcessor:
|
||||||
"""Complete observation processor that combines image and state processing.
|
"""Complete observation processor that combines image and state processing.
|
||||||
|
|
||||||
This processor replicates the functionality of the original preprocess_observation
|
This processor replicates the functionality of the original preprocess_observation
|
||||||
|
|||||||
@@ -467,7 +467,35 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
|
|
||||||
|
|
||||||
class ObservationProcessor:
|
class ObservationProcessor:
|
||||||
|
"""Base class for processors that modify only the observation component of a transition.
|
||||||
|
|
||||||
|
Subclasses should override the `observation` method to implement custom observation processing.
|
||||||
|
This class handles the boilerplate of extracting and reinserting the processed observation
|
||||||
|
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
class MyObservationScaler(ObservationProcessor):
|
||||||
|
def __init__(self, scale_factor):
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
|
def observation(self, observation):
|
||||||
|
return observation * self.scale_factor
|
||||||
|
```
|
||||||
|
|
||||||
|
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||||
|
manipulation, focusing only on the specific observation processing logic.
|
||||||
|
"""
|
||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
|
"""Process the observation component.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observation: The observation to process
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The processed observation
|
||||||
|
"""
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
@@ -478,7 +506,36 @@ class ObservationProcessor:
|
|||||||
|
|
||||||
|
|
||||||
class ActionProcessor:
|
class ActionProcessor:
|
||||||
|
"""Base class for processors that modify only the action component of a transition.
|
||||||
|
|
||||||
|
Subclasses should override the `action` method to implement custom action processing.
|
||||||
|
This class handles the boilerplate of extracting and reinserting the processed action
|
||||||
|
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
class ActionClipping(ActionProcessor):
|
||||||
|
def __init__(self, min_val, max_val):
|
||||||
|
self.min_val = min_val
|
||||||
|
self.max_val = max_val
|
||||||
|
|
||||||
|
def action(self, action):
|
||||||
|
return np.clip(action, self.min_val, self.max_val)
|
||||||
|
```
|
||||||
|
|
||||||
|
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||||
|
manipulation, focusing only on the specific action processing logic.
|
||||||
|
"""
|
||||||
|
|
||||||
def action(self, action):
|
def action(self, action):
|
||||||
|
"""Process the action component.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: The action to process
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The processed action
|
||||||
|
"""
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
@@ -489,7 +546,35 @@ class ActionProcessor:
|
|||||||
|
|
||||||
|
|
||||||
class RewardProcessor:
|
class RewardProcessor:
|
||||||
|
"""Base class for processors that modify only the reward component of a transition.
|
||||||
|
|
||||||
|
Subclasses should override the `reward` method to implement custom reward processing.
|
||||||
|
This class handles the boilerplate of extracting and reinserting the processed reward
|
||||||
|
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
class RewardScaler(RewardProcessor):
|
||||||
|
def __init__(self, scale_factor):
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
|
def reward(self, reward):
|
||||||
|
return reward * self.scale_factor
|
||||||
|
```
|
||||||
|
|
||||||
|
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||||
|
manipulation, focusing only on the specific reward processing logic.
|
||||||
|
"""
|
||||||
|
|
||||||
def reward(self, reward):
|
def reward(self, reward):
|
||||||
|
"""Process the reward component.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reward: The reward to process
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The processed reward
|
||||||
|
"""
|
||||||
return reward
|
return reward
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
@@ -505,7 +590,40 @@ class RewardProcessor:
|
|||||||
|
|
||||||
|
|
||||||
class DoneProcessor:
|
class DoneProcessor:
|
||||||
|
"""Base class for processors that modify only the done flag of a transition.
|
||||||
|
|
||||||
|
Subclasses should override the `done` method to implement custom done flag processing.
|
||||||
|
This class handles the boilerplate of extracting and reinserting the processed done flag
|
||||||
|
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
class TimeoutDone(DoneProcessor):
|
||||||
|
def __init__(self, max_steps):
|
||||||
|
self.steps = 0
|
||||||
|
self.max_steps = max_steps
|
||||||
|
|
||||||
|
def done(self, done):
|
||||||
|
self.steps += 1
|
||||||
|
return done or self.steps >= self.max_steps
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.steps = 0
|
||||||
|
```
|
||||||
|
|
||||||
|
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||||
|
manipulation, focusing only on the specific done flag processing logic.
|
||||||
|
"""
|
||||||
|
|
||||||
def done(self, done):
|
def done(self, done):
|
||||||
|
"""Process the done flag.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
done: The done flag to process
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The processed done flag
|
||||||
|
"""
|
||||||
return done
|
return done
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
@@ -522,7 +640,36 @@ class DoneProcessor:
|
|||||||
|
|
||||||
|
|
||||||
class TruncatedProcessor:
|
class TruncatedProcessor:
|
||||||
|
"""Base class for processors that modify only the truncated flag of a transition.
|
||||||
|
|
||||||
|
Subclasses should override the `truncated` method to implement custom truncated flag processing.
|
||||||
|
This class handles the boilerplate of extracting and reinserting the processed truncated flag
|
||||||
|
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
class EarlyTruncation(TruncatedProcessor):
|
||||||
|
def __init__(self, threshold):
|
||||||
|
self.threshold = threshold
|
||||||
|
|
||||||
|
def truncated(self, truncated):
|
||||||
|
# Additional truncation condition
|
||||||
|
return truncated or some_condition > self.threshold
|
||||||
|
```
|
||||||
|
|
||||||
|
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||||
|
manipulation, focusing only on the specific truncated flag processing logic.
|
||||||
|
"""
|
||||||
|
|
||||||
def truncated(self, truncated):
|
def truncated(self, truncated):
|
||||||
|
"""Process the truncated flag.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
truncated: The truncated flag to process
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The processed truncated flag
|
||||||
|
"""
|
||||||
return truncated
|
return truncated
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
@@ -540,7 +687,41 @@ class TruncatedProcessor:
|
|||||||
|
|
||||||
|
|
||||||
class InfoProcessor:
|
class InfoProcessor:
|
||||||
|
"""Base class for processors that modify only the info dictionary of a transition.
|
||||||
|
|
||||||
|
Subclasses should override the `info` method to implement custom info processing.
|
||||||
|
This class handles the boilerplate of extracting and reinserting the processed info
|
||||||
|
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
class InfoAugmenter(InfoProcessor):
|
||||||
|
def __init__(self):
|
||||||
|
self.step_count = 0
|
||||||
|
|
||||||
|
def info(self, info):
|
||||||
|
info = info.copy() # Create a copy to avoid modifying the original
|
||||||
|
info['steps'] = self.step_count
|
||||||
|
self.step_count += 1
|
||||||
|
return info
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.step_count = 0
|
||||||
|
```
|
||||||
|
|
||||||
|
By inheriting from this class, you avoid writing repetitive code to handle transition tuple
|
||||||
|
manipulation, focusing only on the specific info dictionary processing logic.
|
||||||
|
"""
|
||||||
|
|
||||||
def info(self, info):
|
def info(self, info):
|
||||||
|
"""Process the info dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
info: The info dictionary to process
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The processed info dictionary
|
||||||
|
"""
|
||||||
return info
|
return info
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ import torch
|
|||||||
|
|
||||||
from lerobot.processor.observation_processor import (
|
from lerobot.processor.observation_processor import (
|
||||||
ImageProcessor,
|
ImageProcessor,
|
||||||
ObservationProcessor,
|
|
||||||
StateProcessor,
|
StateProcessor,
|
||||||
|
VanillaObservationProcessor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -262,7 +262,7 @@ def test_no_states_in_observation():
|
|||||||
|
|
||||||
def test_complete_observation_processing():
|
def test_complete_observation_processing():
|
||||||
"""Test processing a complete observation with both images and states."""
|
"""Test processing a complete observation with both images and states."""
|
||||||
processor = ObservationProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
# Create mock data
|
# Create mock data
|
||||||
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||||
@@ -299,7 +299,7 @@ def test_complete_observation_processing():
|
|||||||
|
|
||||||
def test_image_only_processing():
|
def test_image_only_processing():
|
||||||
"""Test processing observation with only images."""
|
"""Test processing observation with only images."""
|
||||||
processor = ObservationProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||||
observation = {"pixels": image}
|
observation = {"pixels": image}
|
||||||
@@ -314,7 +314,7 @@ def test_image_only_processing():
|
|||||||
|
|
||||||
def test_state_only_processing():
|
def test_state_only_processing():
|
||||||
"""Test processing observation with only states."""
|
"""Test processing observation with only states."""
|
||||||
processor = ObservationProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
|
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
|
||||||
observation = {"agent_pos": agent_pos}
|
observation = {"agent_pos": agent_pos}
|
||||||
@@ -329,7 +329,7 @@ def test_state_only_processing():
|
|||||||
|
|
||||||
def test_empty_observation():
|
def test_empty_observation():
|
||||||
"""Test processing empty observation."""
|
"""Test processing empty observation."""
|
||||||
processor = ObservationProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
observation = {}
|
observation = {}
|
||||||
transition = (observation, None, None, None, None, None, None)
|
transition = (observation, None, None, None, None, None, None)
|
||||||
@@ -344,7 +344,7 @@ def test_custom_sub_processors():
|
|||||||
"""Test ObservationProcessor with custom sub-processors."""
|
"""Test ObservationProcessor with custom sub-processors."""
|
||||||
image_proc = ImageProcessor()
|
image_proc = ImageProcessor()
|
||||||
state_proc = StateProcessor()
|
state_proc = StateProcessor()
|
||||||
processor = ObservationProcessor(image_processor=image_proc, state_processor=state_proc)
|
processor = VanillaObservationProcessor(image_processor=image_proc, state_processor=state_proc)
|
||||||
|
|
||||||
# Should use the provided processors
|
# Should use the provided processors
|
||||||
assert processor.image_processor is image_proc
|
assert processor.image_processor is image_proc
|
||||||
@@ -356,7 +356,7 @@ def test_equivalent_to_original_function():
|
|||||||
# Import the original function for comparison
|
# Import the original function for comparison
|
||||||
from lerobot.envs.utils import preprocess_observation
|
from lerobot.envs.utils import preprocess_observation
|
||||||
|
|
||||||
processor = ObservationProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
# Create test data similar to what the original function expects
|
# Create test data similar to what the original function expects
|
||||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||||
@@ -383,7 +383,7 @@ def test_equivalent_with_image_dict():
|
|||||||
"""Test equivalence with dictionary of images."""
|
"""Test equivalence with dictionary of images."""
|
||||||
from lerobot.envs.utils import preprocess_observation
|
from lerobot.envs.utils import preprocess_observation
|
||||||
|
|
||||||
processor = ObservationProcessor()
|
processor = VanillaObservationProcessor()
|
||||||
|
|
||||||
# Create test data with multiple cameras
|
# Create test data with multiple cameras
|
||||||
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||||
|
|||||||
Reference in New Issue
Block a user