diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index ff4688d24..c5aaa7001 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -36,11 +36,11 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten Returns: Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. """ - from lerobot.processor.observation_processor import ObservationProcessor + from lerobot.processor.observation_processor import VanillaObservationProcessor from lerobot.processor.pipeline import RobotProcessor, TransitionIndex # Create processor with observation processor - processor = RobotProcessor([ObservationProcessor()]) + processor = RobotProcessor([VanillaObservationProcessor()]) # Create transition tuple and process transition = (observations, None, None, None, None, None, None) diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index bcf49c905..f6acdee9e 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -16,19 +16,35 @@ from .normalize_processor import NormalizationProcessor from .observation_processor import ( ImageProcessor, - ObservationProcessor, StateProcessor, + VanillaObservationProcessor, +) +from .pipeline import ( + ActionProcessor, + DoneProcessor, + EnvTransition, + InfoProcessor, + ObservationProcessor, + ProcessorStep, + RewardProcessor, + RobotProcessor, + TruncatedProcessor, ) -from .pipeline import EnvTransition, ProcessorStep, RobotProcessor from .rename_processor import RenameProcessor __all__ = [ - "RobotProcessor", - "ProcessorStep", + "ActionProcessor", + "DoneProcessor", "EnvTransition", "ImageProcessor", - "StateProcessor", - "ObservationProcessor", + "InfoProcessor", "NormalizationProcessor", + "ObservationProcessor", + "ProcessorStep", "RenameProcessor", + "RewardProcessor", + "RobotProcessor", + "StateProcessor", + "TruncatedProcessor", + "VanillaObservationProcessor", ] diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 3ef6526fc..0d3718c43 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -181,7 +181,7 @@ class StateProcessor: @dataclass @ProcessorStepRegistry.register(name="observation_processor") -class ObservationProcessor: +class VanillaObservationProcessor: """Complete observation processor that combines image and state processing. This processor replicates the functionality of the original preprocess_observation diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index adbfaba19..ac900fcff 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -467,7 +467,35 @@ class RobotProcessor(ModelHubMixin): class ObservationProcessor: + """Base class for processors that modify only the observation component of a transition. + + Subclasses should override the `observation` method to implement custom observation processing. + This class handles the boilerplate of extracting and reinserting the processed observation + into the transition tuple, eliminating the need to implement the `__call__` method in subclasses. + + Example: + ```python + class MyObservationScaler(ObservationProcessor): + def __init__(self, scale_factor): + self.scale_factor = scale_factor + + def observation(self, observation): + return observation * self.scale_factor + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition tuple + manipulation, focusing only on the specific observation processing logic. + """ + def observation(self, observation): + """Process the observation component. + + Args: + observation: The observation to process + + Returns: + The processed observation + """ return observation def __call__(self, transition: EnvTransition) -> EnvTransition: @@ -478,7 +506,36 @@ class ObservationProcessor: class ActionProcessor: + """Base class for processors that modify only the action component of a transition. + + Subclasses should override the `action` method to implement custom action processing. + This class handles the boilerplate of extracting and reinserting the processed action + into the transition tuple, eliminating the need to implement the `__call__` method in subclasses. + + Example: + ```python + class ActionClipping(ActionProcessor): + def __init__(self, min_val, max_val): + self.min_val = min_val + self.max_val = max_val + + def action(self, action): + return np.clip(action, self.min_val, self.max_val) + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition tuple + manipulation, focusing only on the specific action processing logic. + """ + def action(self, action): + """Process the action component. + + Args: + action: The action to process + + Returns: + The processed action + """ return action def __call__(self, transition: EnvTransition) -> EnvTransition: @@ -489,7 +546,35 @@ class ActionProcessor: class RewardProcessor: + """Base class for processors that modify only the reward component of a transition. + + Subclasses should override the `reward` method to implement custom reward processing. + This class handles the boilerplate of extracting and reinserting the processed reward + into the transition tuple, eliminating the need to implement the `__call__` method in subclasses. + + Example: + ```python + class RewardScaler(RewardProcessor): + def __init__(self, scale_factor): + self.scale_factor = scale_factor + + def reward(self, reward): + return reward * self.scale_factor + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition tuple + manipulation, focusing only on the specific reward processing logic. + """ + def reward(self, reward): + """Process the reward component. + + Args: + reward: The reward to process + + Returns: + The processed reward + """ return reward def __call__(self, transition: EnvTransition) -> EnvTransition: @@ -505,7 +590,40 @@ class RewardProcessor: class DoneProcessor: + """Base class for processors that modify only the done flag of a transition. + + Subclasses should override the `done` method to implement custom done flag processing. + This class handles the boilerplate of extracting and reinserting the processed done flag + into the transition tuple, eliminating the need to implement the `__call__` method in subclasses. + + Example: + ```python + class TimeoutDone(DoneProcessor): + def __init__(self, max_steps): + self.steps = 0 + self.max_steps = max_steps + + def done(self, done): + self.steps += 1 + return done or self.steps >= self.max_steps + + def reset(self): + self.steps = 0 + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition tuple + manipulation, focusing only on the specific done flag processing logic. + """ + def done(self, done): + """Process the done flag. + + Args: + done: The done flag to process + + Returns: + The processed done flag + """ return done def __call__(self, transition: EnvTransition) -> EnvTransition: @@ -522,7 +640,36 @@ class DoneProcessor: class TruncatedProcessor: + """Base class for processors that modify only the truncated flag of a transition. + + Subclasses should override the `truncated` method to implement custom truncated flag processing. + This class handles the boilerplate of extracting and reinserting the processed truncated flag + into the transition tuple, eliminating the need to implement the `__call__` method in subclasses. + + Example: + ```python + class EarlyTruncation(TruncatedProcessor): + def __init__(self, threshold): + self.threshold = threshold + + def truncated(self, truncated): + # Additional truncation condition + return truncated or some_condition > self.threshold + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition tuple + manipulation, focusing only on the specific truncated flag processing logic. + """ + def truncated(self, truncated): + """Process the truncated flag. + + Args: + truncated: The truncated flag to process + + Returns: + The processed truncated flag + """ return truncated def __call__(self, transition: EnvTransition) -> EnvTransition: @@ -540,7 +687,41 @@ class TruncatedProcessor: class InfoProcessor: + """Base class for processors that modify only the info dictionary of a transition. + + Subclasses should override the `info` method to implement custom info processing. + This class handles the boilerplate of extracting and reinserting the processed info + into the transition tuple, eliminating the need to implement the `__call__` method in subclasses. + + Example: + ```python + class InfoAugmenter(InfoProcessor): + def __init__(self): + self.step_count = 0 + + def info(self, info): + info = info.copy() # Create a copy to avoid modifying the original + info['steps'] = self.step_count + self.step_count += 1 + return info + + def reset(self): + self.step_count = 0 + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition tuple + manipulation, focusing only on the specific info dictionary processing logic. + """ + def info(self, info): + """Process the info dictionary. + + Args: + info: The info dictionary to process + + Returns: + The processed info dictionary + """ return info def __call__(self, transition: EnvTransition) -> EnvTransition: diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index 1fffb3e34..62f5cfd30 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -20,8 +20,8 @@ import torch from lerobot.processor.observation_processor import ( ImageProcessor, - ObservationProcessor, StateProcessor, + VanillaObservationProcessor, ) @@ -262,7 +262,7 @@ def test_no_states_in_observation(): def test_complete_observation_processing(): """Test processing a complete observation with both images and states.""" - processor = ObservationProcessor() + processor = VanillaObservationProcessor() # Create mock data image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) @@ -299,7 +299,7 @@ def test_complete_observation_processing(): def test_image_only_processing(): """Test processing observation with only images.""" - processor = ObservationProcessor() + processor = VanillaObservationProcessor() image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) observation = {"pixels": image} @@ -314,7 +314,7 @@ def test_image_only_processing(): def test_state_only_processing(): """Test processing observation with only states.""" - processor = ObservationProcessor() + processor = VanillaObservationProcessor() agent_pos = np.array([1.0, 2.0], dtype=np.float32) observation = {"agent_pos": agent_pos} @@ -329,7 +329,7 @@ def test_state_only_processing(): def test_empty_observation(): """Test processing empty observation.""" - processor = ObservationProcessor() + processor = VanillaObservationProcessor() observation = {} transition = (observation, None, None, None, None, None, None) @@ -344,7 +344,7 @@ def test_custom_sub_processors(): """Test ObservationProcessor with custom sub-processors.""" image_proc = ImageProcessor() state_proc = StateProcessor() - processor = ObservationProcessor(image_processor=image_proc, state_processor=state_proc) + processor = VanillaObservationProcessor(image_processor=image_proc, state_processor=state_proc) # Should use the provided processors assert processor.image_processor is image_proc @@ -356,7 +356,7 @@ def test_equivalent_to_original_function(): # Import the original function for comparison from lerobot.envs.utils import preprocess_observation - processor = ObservationProcessor() + processor = VanillaObservationProcessor() # Create test data similar to what the original function expects image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) @@ -383,7 +383,7 @@ def test_equivalent_with_image_dict(): """Test equivalence with dictionary of images.""" from lerobot.envs.utils import preprocess_observation - processor = ObservationProcessor() + processor = VanillaObservationProcessor() # Create test data with multiple cameras image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)