diff --git a/docs/source/processor_tutorial.mdx b/docs/source/processor_tutorial.mdx index 0233dc392..9bfa36878 100644 --- a/docs/source/processor_tutorial.mdx +++ b/docs/source/processor_tutorial.mdx @@ -112,24 +112,23 @@ RobotProcessor solves these issues by providing a declarative pipeline approach RobotProcessor works with two data formats: -### 1. EnvTransition Tuple Format +### 1. EnvTransition Dictionary Format -An `EnvTransition` is a 7-tuple that represents a complete transition in the environment: +An `EnvTransition` is a dictionary that represents a complete transition in the environment: ```python -from lerobot.processor.pipeline import TransitionIndex +from lerobot.processor.pipeline import TransitionKey # EnvTransition structure: -# (observation, action, reward, done, truncated, info, complementary_data) -transition = ( - {"observation.image": ..., "observation.state": ...}, # observation at time t - [0.1, -0.2, 0.3], # action taken at time t - 1.0, # reward received - False, # episode done flag - False, # episode truncated flag - {"success": True}, # additional info from environment - {"step_idx": 42} # complementary_data for inter-step communication -) +transition = { + TransitionKey.OBSERVATION: {"observation.image": ..., "observation.state": ...}, # observation at time t + TransitionKey.ACTION: [0.1, -0.2, 0.3], # action taken at time t + TransitionKey.REWARD: 1.0, # reward received + TransitionKey.DONE: False, # episode done flag + TransitionKey.TRUNCATED: False, # episode truncated flag + TransitionKey.INFO: {"success": True}, # additional info from environment + TransitionKey.COMPLEMENTARY_DATA: {"step_idx": 42} # complementary_data for inter-step communication +} ``` ### 2. 
Batch Dictionary Format @@ -160,9 +159,17 @@ from lerobot.processor.observation_processor import ImageProcessor processor = RobotProcessor([ImageProcessor()]) -# Works with EnvTransition tuples -transition = ({"pixels": image_array}, None, 0.0, False, False, {}, {}) -processed_transition = processor(transition) # Returns EnvTransition tuple +# Works with EnvTransition dictionaries +transition = { + TransitionKey.OBSERVATION: {"pixels": image_array}, + TransitionKey.ACTION: None, + TransitionKey.REWARD: 0.0, + TransitionKey.DONE: False, + TransitionKey.TRUNCATED: False, + TransitionKey.INFO: {}, + TransitionKey.COMPLEMENTARY_DATA: {} +} +processed_transition = processor(transition) # Returns EnvTransition dictionary # Also works with batch dictionaries batch = { @@ -176,25 +183,25 @@ batch = { processed_batch = processor(batch) # Returns batch dictionary ``` -### Using TransitionIndex +### Using TransitionKey -Instead of using magic numbers to access tuple elements, use the `TransitionIndex` enum: +Use the `TransitionKey` enum to access dictionary elements: ```python -from lerobot.processor.pipeline import TransitionIndex +from lerobot.processor.pipeline import TransitionKey -# Bad - using magic numbers -obs = transition[0] -action = transition[1] +# Good - using TransitionKey +obs = transition[TransitionKey.OBSERVATION] +action = transition[TransitionKey.ACTION] +reward = transition[TransitionKey.REWARD] +done = transition[TransitionKey.DONE] +truncated = transition[TransitionKey.TRUNCATED] +info = transition[TransitionKey.INFO] +comp_data = transition[TransitionKey.COMPLEMENTARY_DATA] -# Good - using TransitionIndex -obs = transition[TransitionIndex.OBSERVATION] -action = transition[TransitionIndex.ACTION] -reward = transition[TransitionIndex.REWARD] -done = transition[TransitionIndex.DONE] -truncated = transition[TransitionIndex.TRUNCATED] -info = transition[TransitionIndex.INFO] -comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] +# Alternative - using 
.get() for optional access +obs = transition.get(TransitionKey.OBSERVATION) +action = transition.get(TransitionKey.ACTION) ``` ### Default Conversion Functions @@ -203,43 +210,49 @@ RobotProcessor uses these default conversion functions: ```python def _default_batch_to_transition(batch): - """Default conversion from batch dict to EnvTransition tuple.""" + """Default conversion from batch dict to EnvTransition dictionary.""" # Extract observation keys (anything starting with "observation.") observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} + observation = observation_keys if observation_keys else None - observation = None - if observation_keys: - observation = {} - # Keep observation.* keys as-is (don't remove "observation." prefix) - for key, value in observation_keys.items(): - observation[key] = value + # Extract padding and task keys for complementary data + pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} + task_key = {"task": batch["task"]} if "task" in batch else {} + complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {} - return ( - observation, - batch.get("action"), - batch.get("next.reward", 0.0), # Note: "next.reward" not "reward" - batch.get("next.done", False), # Note: "next.done" not "done" - batch.get("next.truncated", False), # Note: "next.truncated" not "truncated" - batch.get("info", {}), - {}, # Empty complementary_data - ) + transition = { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: batch.get("action"), + TransitionKey.REWARD: batch.get("next.reward", 0.0), + TransitionKey.DONE: batch.get("next.done", False), + TransitionKey.TRUNCATED: batch.get("next.truncated", False), + TransitionKey.INFO: batch.get("info", {}), + TransitionKey.COMPLEMENTARY_DATA: complementary_data, + } + return transition def _default_transition_to_batch(transition): - """Default conversion from EnvTransition tuple to batch dict.""" - obs, action, reward, done, truncated, info, _ = 
transition - + """Default conversion from EnvTransition dictionary to batch dict.""" batch = { - "action": action, - "next.reward": reward, # Note: "next.reward" not "reward" - "next.done": done, # Note: "next.done" not "done" - "next.truncated": truncated, # Note: "next.truncated" not "truncated" - "info": info, + "action": transition.get(TransitionKey.ACTION), + "next.reward": transition.get(TransitionKey.REWARD, 0.0), + "next.done": transition.get(TransitionKey.DONE, False), + "next.truncated": transition.get(TransitionKey.TRUNCATED, False), + "info": transition.get(TransitionKey.INFO, {}), } - # Flatten observation dict (keep observation.* keys as-is) - if isinstance(obs, dict): - for key, value in obs.items(): - batch[key] = value + # Add padding and task data from complementary_data + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data: + pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k} + batch.update(pad_data) + if "task" in complementary_data: + batch["task"] = complementary_data["task"] + + # Handle observation - flatten dict to observation.* keys if it's a dict + observation = transition.get(TransitionKey.OBSERVATION) + if isinstance(observation, dict): + batch.update(observation) return batch ``` @@ -250,33 +263,32 @@ You can customize how RobotProcessor converts between formats: ```python def custom_batch_to_transition(batch): - """Custom conversion from batch dict to EnvTransition tuple.""" + """Custom conversion from batch dict to EnvTransition dictionary.""" # Extract observation keys (anything starting with "observation.") observation = {k: v for k, v in batch.items() if k.startswith("observation.")} - return ( - observation, - batch.get("action"), - batch.get("reward", 0.0), # Use "reward" instead of "next.reward" - batch.get("done", False), # Use "done" instead of "next.done" - batch.get("truncated", False), - batch.get("info", {}), - batch.get("complementary_data", {}) - ) + 
return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: batch.get("action"), + TransitionKey.REWARD: batch.get("reward", 0.0), # Use "reward" instead of "next.reward" + TransitionKey.DONE: batch.get("done", False), # Use "done" instead of "next.done" + TransitionKey.TRUNCATED: batch.get("truncated", False), + TransitionKey.INFO: batch.get("info", {}), + TransitionKey.COMPLEMENTARY_DATA: batch.get("complementary_data", {}) + } def custom_transition_to_batch(transition): - """Custom conversion from EnvTransition tuple to batch dict.""" - obs, action, reward, done, truncated, info, comp_data = transition - + """Custom conversion from EnvTransition dictionary to batch dict.""" batch = { - "action": action, - "reward": reward, # Use "reward" instead of "next.reward" - "done": done, # Use "done" instead of "next.done" - "truncated": truncated, - "info": info, + "action": transition.get(TransitionKey.ACTION), + "reward": transition.get(TransitionKey.REWARD), # Use "reward" instead of "next.reward" + "done": transition.get(TransitionKey.DONE), # Use "done" instead of "next.done" + "truncated": transition.get(TransitionKey.TRUNCATED), + "info": transition.get(TransitionKey.INFO), } # Flatten observation dict + obs = transition.get(TransitionKey.OBSERVATION) if obs: batch.update(obs) @@ -292,21 +304,21 @@ processor = RobotProcessor( ### Advanced: Controlling Output Format with `to_output` -The `to_output` function determines what format is returned when you call the processor with a batch dictionary. Sometimes you want to output `EnvTransition` tuples even when you input batch dictionaries: +The `to_output` function determines what format is returned when you call the processor with a batch dictionary. 
Sometimes you want to output `EnvTransition` dictionaries even when you input batch dictionaries: ```python -# Identity function to always return EnvTransition tuples +# Identity function to always return EnvTransition dictionaries def keep_as_transition(transition): - """Always return EnvTransition tuple regardless of input format.""" + """Always return EnvTransition dictionary regardless of input format.""" return transition -# Processor that always outputs EnvTransition tuples +# Processor that always outputs EnvTransition dictionaries processor = RobotProcessor( steps=[ImageProcessor(), StateProcessor()], - to_output=keep_as_transition # Always return tuple format + to_output=keep_as_transition # Always return dictionary format ) -# Even when called with batch dict, returns EnvTransition tuple +# Even when called with batch dict, returns EnvTransition dictionary batch = { "observation.image": image_tensor, "action": action_tensor, @@ -316,13 +328,13 @@ batch = { "info": info_dict } -result = processor(batch) # Returns EnvTransition tuple, not batch dict! -print(type(result)) # +result = processor(batch) # Returns EnvTransition dictionary, not batch dict! 
+print(type(result)) # ``` ### Real-World Example: Environment Interaction -This is particularly useful for environment interaction where you want consistent tuple output: +This is particularly useful for environment interaction where you want consistent dictionary output: ```python from lerobot.processor.observation_processor import VanillaObservationProcessor @@ -332,7 +344,7 @@ from lerobot.processor.observation_processor import VanillaObservationProcessor env_processor = RobotProcessor( [VanillaObservationProcessor()], to_transition=lambda x: x, # Pass through - no conversion needed - to_output=lambda x: x, # Always return EnvTransition tuple + to_output=lambda x: x, # Always return EnvTransition dictionary ) # Environment interaction loop @@ -340,12 +352,20 @@ env = make_env() obs, info = env.reset() for step in range(1000): - # Create transition - input is already in tuple format - transition = (obs, None, 0.0, False, False, info, {"step": step}) + # Create transition - input is already in dictionary format + transition = { + TransitionKey.OBSERVATION: obs, + TransitionKey.ACTION: None, + TransitionKey.REWARD: 0.0, + TransitionKey.DONE: False, + TransitionKey.TRUNCATED: False, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: {"step": step} + } - # Process - output is guaranteed to be EnvTransition tuple + # Process - output is guaranteed to be EnvTransition dictionary processed_transition = env_processor(transition) - processed_obs = processed_transition[TransitionIndex.OBSERVATION] + processed_obs = processed_transition[TransitionKey.OBSERVATION] # Use with policy action = policy.select_action(processed_obs) @@ -357,12 +377,12 @@ for step in range(1000): ### When to Use Different Output Formats -**Use EnvTransition tuple output when:** +**Use EnvTransition dictionary output when:** - Environment interaction and real-time control - You need to access individual transition components frequently -- Performance is critical (avoids dictionary 
creation overhead) -- Working with gym environments that expect tuple format +- Working with gym environments that expect structured data +- You need the flexibility of dictionary operations **Use batch dictionary output when:** @@ -372,10 +392,10 @@ for step in range(1000): - You need the standardized "next.\*" key format ```python -# For environment interaction - use tuple output +# For environment interaction - use dictionary output env_processor = RobotProcessor( steps=[ImageProcessor(), StateProcessor()], - to_output=lambda x: x # Return EnvTransition tuple + to_output=lambda x: x # Return EnvTransition dictionary ) # For training - use batch output (default) @@ -391,9 +411,17 @@ for batch in dataloader: # Environment loop for step in range(1000): - transition = (obs, None, 0.0, False, False, info, {}) - processed_transition = env_processor(transition) # Returns EnvTransition tuple - obs = processed_transition[TransitionIndex.OBSERVATION] + transition = { + TransitionKey.OBSERVATION: obs, + TransitionKey.ACTION: None, + TransitionKey.REWARD: 0.0, + TransitionKey.DONE: False, + TransitionKey.TRUNCATED: False, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: {} + } + processed_transition = env_processor(transition) # Returns EnvTransition dictionary + obs = processed_transition[TransitionKey.OBSERVATION] action = policy.select_action(obs) ``` @@ -426,7 +454,7 @@ batch = { Let's create a processor that properly handles image and state preprocessing: ```python -from lerobot.processor.pipeline import RobotProcessor, TransitionIndex +from lerobot.processor.pipeline import RobotProcessor, TransitionKey from lerobot.processor.observation_processor import ImageProcessor, StateProcessor import numpy as np @@ -440,7 +468,15 @@ observation = { } # Create a full transition -transition = (observation, None, 0.0, False, False, {}, {}) +transition = { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: None, + TransitionKey.REWARD: 0.0, + 
TransitionKey.DONE: False, + TransitionKey.TRUNCATED: False, + TransitionKey.INFO: {}, + TransitionKey.COMPLEMENTARY_DATA: {} +} # Create and use the processor processor = RobotProcessor([ @@ -449,7 +485,7 @@ processor = RobotProcessor([ ]) processed_transition = processor(transition) -processed_obs = processed_transition[TransitionIndex.OBSERVATION] +processed_obs = processed_transition[TransitionKey.OBSERVATION] # Check the results print("Original keys:", observation.keys()) @@ -541,11 +577,19 @@ obs, info = env.reset() for step in range(1000): # Raw environment observation - transition = (obs, None, 0.0, False, False, info, {}) + transition = { + TransitionKey.OBSERVATION: obs, + TransitionKey.ACTION: None, + TransitionKey.REWARD: 0.0, + TransitionKey.DONE: False, + TransitionKey.TRUNCATED: False, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: {} + } # Process for policy input processed_transition = online_processor(transition) - processed_obs = processed_transition[TransitionIndex.OBSERVATION] + processed_obs = processed_transition[TransitionKey.OBSERVATION] # Get action from policy action = policy.select_action(processed_obs) @@ -585,15 +629,16 @@ class ImagePadder: def __call__(self, transition: EnvTransition) -> EnvTransition: """Main processing method - required for all steps.""" - obs = transition[TransitionIndex.OBSERVATION] + obs = transition.get(TransitionKey.OBSERVATION) if obs is None: return transition # Process all image observations - for key in list(obs.keys()): + processed_obs = dict(obs) # Create a copy + for key in list(processed_obs.keys()): if key.startswith("observation.images."): - img = obs[key] + img = processed_obs[key] # Calculate padding _, _, h, w = img.shape pad_h = max(0, self.target_height - h) @@ -609,10 +654,12 @@ class ImagePadder: img = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=self.pad_value) - obs[key] = img + processed_obs[key] = img # Return modified transition - return 
(obs, *transition[1:]) + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = processed_obs + return new_transition def get_config(self) -> Dict[str, Any]: """Return JSON-serializable configuration - required for save/load.""" @@ -694,8 +741,8 @@ class ImageStatisticsCalculator: """Calculate image statistics and pass to next steps.""" def __call__(self, transition: EnvTransition) -> EnvTransition: - obs = transition[TransitionIndex.OBSERVATION] - comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {} + obs = transition.get(TransitionKey.OBSERVATION) + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} if obs is None: return transition @@ -714,18 +761,13 @@ class ImageStatisticsCalculator: image_stats[key] = stats # Store in complementary_data for next steps + comp_data = dict(comp_data) # Make a copy comp_data["image_statistics"] = image_stats # Return transition with updated complementary_data - return ( - obs, - transition[TransitionIndex.ACTION], - transition[TransitionIndex.REWARD], - transition[TransitionIndex.DONE], - transition[TransitionIndex.TRUNCATED], - transition[TransitionIndex.INFO], - comp_data # Updated complementary_data - ) + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition @dataclass class AdaptiveBrightnessAdjuster: @@ -734,8 +776,8 @@ class AdaptiveBrightnessAdjuster: target_brightness: float = 0.5 def __call__(self, transition: EnvTransition) -> EnvTransition: - obs = transition[TransitionIndex.OBSERVATION] - comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {} + obs = transition.get(TransitionKey.OBSERVATION) + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} if obs is None or "image_statistics" not in comp_data: return transition @@ -743,15 +785,18 @@ class AdaptiveBrightnessAdjuster: # Use statistics from previous step image_stats = comp_data["image_statistics"] - for key in obs: + 
processed_obs = dict(obs) # Create a copy + for key in processed_obs: if key.startswith("observation.images.") and key in image_stats: current_mean = image_stats[key]["mean"] brightness_adjust = self.target_brightness - current_mean # Adjust brightness - obs[key] = torch.clamp(obs[key] + brightness_adjust, 0, 1) + processed_obs[key] = torch.clamp(processed_obs[key] + brightness_adjust, 0, 1) - return (obs, *transition[1:]) + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = processed_obs + return new_transition # Use them together processor = RobotProcessor([ @@ -782,7 +827,7 @@ class ActionRepeatStep: env: gym.Env = None # This can't be serialized to JSON! def __call__(self, transition: EnvTransition) -> EnvTransition: - obs, action, reward, done, truncated, info, comp_data = transition + action = transition.get(TransitionKey.ACTION) if self.env is not None and action is not None: # Repeat action multiple times in environment @@ -792,9 +837,13 @@ class ActionRepeatStep: total_reward += r if d or t: break - reward = total_reward - return (obs, action, reward, done, truncated, info, comp_data) + # Update reward in transition + new_transition = transition.copy() + new_transition[TransitionKey.REWARD] = total_reward + return new_transition + + return transition def get_config(self) -> dict[str, Any]: # Note: env is NOT included because it's not serializable @@ -1211,7 +1260,7 @@ This enables sharing of preprocessing logic while allowing each user to provide Here's a complete example showing proper device management and all features: ```python -from lerobot.processor.pipeline import RobotProcessor, ProcessorStepRegistry, TransitionIndex +from lerobot.processor.pipeline import RobotProcessor, ProcessorStepRegistry, TransitionKey import torch import torch.nn.functional as F import numpy as np @@ -1224,23 +1273,29 @@ class DeviceMover: device: str = "cuda" def __call__(self, transition: EnvTransition) -> EnvTransition: - obs = 
transition[TransitionIndex.OBSERVATION] + obs = transition.get(TransitionKey.OBSERVATION) if obs is None: return transition # Move all tensor observations to device - for key, value in obs.items(): + processed_obs = dict(obs) # Create a copy + for key, value in processed_obs.items(): if isinstance(value, torch.Tensor): - obs[key] = value.to(self.device) + processed_obs[key] = value.to(self.device) # Also handle action if present - action = transition[TransitionIndex.ACTION] + action = transition.get(TransitionKey.ACTION) if action is not None and isinstance(action, torch.Tensor): action = action.to(self.device) - return (obs, action, *transition[2:]) + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = processed_obs + new_transition[TransitionKey.ACTION] = action + return new_transition - return (obs, *transition[1:]) + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = processed_obs + return new_transition def get_config(self) -> Dict[str, Any]: return {"device": str(self.device)} @@ -1260,7 +1315,7 @@ class RunningNormalizer: self.initialized = False def __call__(self, transition: EnvTransition) -> EnvTransition: - obs = transition[TransitionIndex.OBSERVATION] + obs = transition.get(TransitionKey.OBSERVATION) if obs is None or "observation.state" not in obs: return transition @@ -1284,9 +1339,12 @@ class RunningNormalizer: # Normalize state_normalized = (state - self.running_mean) / (self.running_var + 1e-8).sqrt() - obs["observation.state"] = state_normalized + processed_obs = dict(obs) # Create a copy + processed_obs["observation.state"] = state_normalized - return (obs, *transition[1:]) + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = processed_obs + return new_transition def get_config(self) -> Dict[str, Any]: return { @@ -1317,7 +1375,7 @@ class RunningNormalizer: processor = RobotProcessor([ ImageProcessor(), # Convert images to float32 [0,1] StateProcessor(), # 
Convert states to torch tensors - ImagePadder(224, 224), # Pad images to standard size + ImagePadder(target_height=224, target_width=224), # Pad images to standard size DeviceMover("cuda"), # Move everything to GPU RunningNormalizer(7), # Normalize states ], name="CompletePreprocessor") @@ -1330,11 +1388,19 @@ obs = { "pixels": {"cam": np.random.randint(0, 255, (200, 300, 3), dtype=np.uint8)}, "agent_pos": np.random.randn(7).astype(np.float32) } -transition = (obs, None, 0.0, False, False, {}, {}) +transition = { + TransitionKey.OBSERVATION: obs, + TransitionKey.ACTION: None, + TransitionKey.REWARD: 0.0, + TransitionKey.DONE: False, + TransitionKey.TRUNCATED: False, + TransitionKey.INFO: {}, + TransitionKey.COMPLEMENTARY_DATA: {} +} # Everything is processed and on GPU processed = processor(transition) -print(processed[TransitionIndex.OBSERVATION]["observation.images.cam"].device) # cuda:0 +print(processed[TransitionKey.OBSERVATION]["observation.images.cam"].device) # cuda:0 ``` ## Solving Real-World Problems with RobotProcessor @@ -1358,22 +1424,25 @@ class KeyRemapper: }) def __call__(self, transition: EnvTransition) -> EnvTransition: - obs = transition[TransitionIndex.OBSERVATION] + obs = transition.get(TransitionKey.OBSERVATION) if obs is None: return transition # Create new observation with renamed keys + processed_obs = dict(obs) # Create a copy renamed_obs = {} for old_key, new_key in self.key_mapping.items(): - if old_key in obs: - renamed_obs[new_key] = obs[old_key] + if old_key in processed_obs: + renamed_obs[new_key] = processed_obs[old_key] # Keep any unmapped keys as-is - for key, value in obs.items(): + for key, value in processed_obs.items(): if key not in self.key_mapping: renamed_obs[key] = value - return (renamed_obs, *transition[1:]) + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = renamed_obs + return new_transition ``` ### Workspace-Focused Image Processing @@ -1390,13 +1459,14 @@ class WorkspaceCropper: 
output_size: Tuple[int, int] = (224, 224) def __call__(self, transition: EnvTransition) -> EnvTransition: - obs = transition[TransitionIndex.OBSERVATION] + obs = transition.get(TransitionKey.OBSERVATION) if obs is None: return transition - for key in list(obs.keys()): + processed_obs = dict(obs) # Create a copy + for key in list(processed_obs.keys()): if key.startswith("observation.images."): - img = obs[key] + img = processed_obs[key] # Crop to workspace x1, y1, x2, y2 = self.crop_bbox img_cropped = img[:, :, y1:y2, x1:x2] @@ -1407,9 +1477,11 @@ mode='bilinear', align_corners=False ) - obs[key] = img_resized + processed_obs[key] = img_resized - return (obs, *transition[1:]) + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = processed_obs + return new_transition ``` ### Building Complete Pipelines for Different Robots @@ -1471,7 +1543,7 @@ The beauty of this approach is that: ```python def __call__(self, transition: EnvTransition) -> EnvTransition: - obs = transition[TransitionIndex.OBSERVATION] + obs = transition.get(TransitionKey.OBSERVATION) # Always check if observation exists if obs is None: @@ -1496,7 +1568,7 @@ return (modified_obs, None, 0.0, False, False, {}, {}) ```python def __call__(self, transition: EnvTransition) -> EnvTransition: - obs = transition[TransitionIndex.OBSERVATION] + obs = transition.get(TransitionKey.OBSERVATION) if self.store_previous: # Good - clone to avoid reference issues @@ -1522,7 +1594,7 @@ def state_dict(self) -> Dict[str, torch.Tensor]: -Here's how to use RobotProcessor in a real robot control loop, showing both tuple and batch formats: +Here's how to use RobotProcessor in a real robot control loop, showing both dictionary and batch formats: ```python -from lerobot.processor.pipeline import RobotProcessor, ProcessorStepRegistry, TransitionIndex +from lerobot.processor.pipeline import RobotProcessor, ProcessorStepRegistry, TransitionKey from lerobot.policies.act.modeling_act import ACTPolicy from pathlib import Path import time @@ -1545,11 +1617,13 @@ class ActionClipper:
max_value: float = 1.0 def __call__(self, transition: EnvTransition) -> EnvTransition: - action = transition[TransitionIndex.ACTION] + action = transition.get(TransitionKey.ACTION) if action is not None: action = torch.clamp(action, self.min_value, self.max_value) - return (transition[TransitionIndex.OBSERVATION], action, *transition[2:]) + new_transition = transition.copy() + new_transition[TransitionKey.ACTION] = action + return new_transition return transition @@ -1578,28 +1652,36 @@ for episode in range(10): for step in range(1000): # Create transition with raw observation - transition = (obs, None, 0.0, False, False, info, {"step": step}) + transition = { + TransitionKey.OBSERVATION: obs, + TransitionKey.ACTION: None, + TransitionKey.REWARD: 0.0, + TransitionKey.DONE: False, + TransitionKey.TRUNCATED: False, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: {"step": step} + } - # Preprocess - works with tuple format + # Preprocess - works with dictionary format processed_transition = preprocessor(transition) - processed_obs = processed_transition[TransitionIndex.OBSERVATION] + processed_obs = processed_transition.get(TransitionKey.OBSERVATION) # Get action from policy with torch.no_grad(): action = policy.select_action(processed_obs) # Postprocess action - action_transition = ( - processed_obs, - action, - 0.0, - False, - False, - info, - {"raw_action": action.clone()} # Store raw action in complementary_data - ) + action_transition = { + TransitionKey.OBSERVATION: processed_obs, + TransitionKey.ACTION: action, + TransitionKey.REWARD: 0.0, + TransitionKey.DONE: False, + TransitionKey.TRUNCATED: False, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: {"raw_action": action.clone()} # Store raw action in complementary_data + } processed_action_transition = postprocessor(action_transition) - final_action = processed_action_transition[TransitionIndex.ACTION] + final_action = processed_action_transition.get(TransitionKey.ACTION) # 
Execute action obs, reward, terminated, truncated, info = env.step(final_action.cpu().numpy()) @@ -1667,7 +1749,7 @@ Use the full power of `RobotProcessor` for debugging: ```python # Enable detailed logging def log_observation_shapes(step_idx: int, transition: EnvTransition): - obs = transition[TransitionIndex.OBSERVATION] + obs = transition.get(TransitionKey.OBSERVATION) if obs: print(f"Step {step_idx} observations:") for key, value in obs.items(): @@ -1679,7 +1761,7 @@ processor.register_after_step_hook(log_observation_shapes) # Monitor complementary data flow def monitor_complementary_data(step_idx: int, transition: EnvTransition): - comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) if comp_data: print(f"Step {step_idx} complementary_data: {list(comp_data.keys())}") return None @@ -1688,7 +1770,7 @@ processor.register_before_step_hook(monitor_complementary_data) # Validate data integrity def validate_tensors(step_idx: int, transition: EnvTransition): - obs = transition[TransitionIndex.OBSERVATION] + obs = transition.get(TransitionKey.OBSERVATION) if obs: for key, value in obs.items(): if isinstance(value, torch.Tensor): @@ -1705,21 +1787,21 @@ processor.register_after_step_hook(validate_tensors) RobotProcessor provides a powerful, modular approach to data preprocessing in robotics: -- **Dual format support**: Works seamlessly with both EnvTransition tuples and batch dictionaries -- **Automatic format conversion**: Converts between tuple and batch formats as needed +- **Dual format support**: Works seamlessly with both EnvTransition dictionaries and batch dictionaries +- **Automatic format conversion**: Converts between dictionary and batch formats as needed - **LeRobot integration**: Native support for LeRobotDataset and ReplayBuffer formats - **Clear separation of concerns**: Each transformation is a separate, testable unit - **Proper state management**: Clear distinction between config 
(JSON) and state (tensors) - **Device-aware**: Seamless GPU/CPU transfers with `.to(device)` - **Inter-step communication**: Use `complementary_data` for passing information - **Easy sharing**: Push to Hugging Face Hub for reproducibility -- **Type safety**: Use `TransitionIndex` instead of magic numbers +- **Type safety**: Use `TransitionKey` instead of magic numbers - **Debugging tools**: Step through transformations and add monitoring hooks - **Flexible conversion**: Customize `to_transition` and `to_output` functions for specific needs Key advantages of the dual format approach: -- **Environment interaction**: Use tuple format for real-time robot control +- **Environment interaction**: Use dictionary format for real-time robot control - **Training/evaluation**: Use batch format for dataset processing and model training - **Seamless integration**: Same processor works with both formats automatically - **Backward compatibility**: Existing code using either format continues to work diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index a65023d32..2fb85ed20 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -36,17 +36,25 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten Returns: Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. 
""" - from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor + from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor # Create processor with observation processor processor = RobotProcessor([VanillaObservationProcessor()]) - # Create transition tuple and process - transition = (observations, None, None, None, None, None, None) - processed_transition = processor(transition) + # Create transition dictionary and process + transition = { + TransitionKey.OBSERVATION: observations, + TransitionKey.ACTION: None, + TransitionKey.REWARD: None, + TransitionKey.DONE: None, + TransitionKey.TRUNCATED: None, + TransitionKey.INFO: None, + TransitionKey.COMPLEMENTARY_DATA: None, + } + result = processor(transition) - # Return processed observations - return processed_transition[TransitionIndex.OBSERVATION] + # Extract and return the processed observation + return result[TransitionKey.OBSERVATION] def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 5dd2e0125..0a5a5dd2c 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -32,7 +32,7 @@ from .pipeline import ( ProcessorStepRegistry, RewardProcessor, RobotProcessor, - TransitionIndex, + TransitionKey, TruncatedProcessor, ) from .rename_processor import RenameProcessor @@ -54,7 +54,7 @@ __all__ = [ "RewardProcessor", "RobotProcessor", "StateProcessor", - "TransitionIndex", + "TransitionKey", "TruncatedProcessor", "VanillaObservationProcessor", ] diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index feb5eb72b..232454850 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -18,7 +18,7 @@ from typing import Any import torch -from lerobot.processor.pipeline import EnvTransition, TransitionIndex +from 
lerobot.processor.pipeline import EnvTransition, TransitionKey @dataclass @@ -35,30 +35,41 @@ class DeviceProcessor: self.non_blocking = "cuda" in self.device def __call__(self, transition: EnvTransition) -> EnvTransition: - observation: dict[str, torch.Tensor] = transition[TransitionIndex.OBSERVATION] - action = transition[TransitionIndex.ACTION] - reward = transition[TransitionIndex.REWARD] - done = transition[TransitionIndex.DONE] - truncated = transition[TransitionIndex.TRUNCATED] - info = transition[TransitionIndex.INFO] - complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA] + # Create a copy of the transition + new_transition = transition.copy() + # Process observation tensors + observation = transition.get(TransitionKey.OBSERVATION) if observation is not None: - observation = { - k: v.to(self.device, non_blocking=self.non_blocking) for k, v in observation.items() + new_observation = { + k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v + for k, v in observation.items() } - if action is not None: - action = action.to(self.device) + new_transition[TransitionKey.OBSERVATION] = new_observation - return ( - observation, - action, - reward, - done, - truncated, - info, - complementary_data, - ) + # Process action tensor + action = transition.get(TransitionKey.ACTION) + if action is not None and isinstance(action, torch.Tensor): + new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking) + + # Process reward tensor + reward = transition.get(TransitionKey.REWARD) + if reward is not None and isinstance(reward, torch.Tensor): + new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking) + + # Process done tensor + done = transition.get(TransitionKey.DONE) + if done is not None and isinstance(done, torch.Tensor): + new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking) + + # Process truncated tensor + truncated 
= transition.get(TransitionKey.TRUNCATED) + if truncated is not None and isinstance(truncated, torch.Tensor): + new_transition[TransitionKey.TRUNCATED] = truncated.to( + self.device, non_blocking=self.non_blocking + ) + + return new_transition def get_config(self) -> dict[str, Any]: """Return configuration for serialization.""" diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index e037b3c8c..70c4f764f 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -1,8 +1,8 @@ from __future__ import annotations +from collections.abc import Mapping from dataclasses import dataclass, field from typing import Any -from collections.abc import Mapping import numpy as np import torch @@ -10,7 +10,7 @@ from torch import Tensor from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex +from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]: @@ -166,17 +166,14 @@ class NormalizerProcessor: raise ValueError("Action stats must contain either ('mean','std') or ('min','max')") def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = self._normalize_obs(transition[TransitionIndex.OBSERVATION]) - action = self._normalize_action(transition[TransitionIndex.ACTION]) - return ( - observation, - action, - transition[TransitionIndex.REWARD], - transition[TransitionIndex.DONE], - transition[TransitionIndex.TRUNCATED], - transition[TransitionIndex.INFO], - transition[TransitionIndex.COMPLEMENTARY_DATA], - ) + observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION)) + action = self._normalize_action(transition.get(TransitionKey.ACTION)) + + 
# Create a new transition with normalized values + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = observation + new_transition[TransitionKey.ACTION] = action + return new_transition def get_config(self) -> dict[str, Any]: config = { @@ -297,17 +294,14 @@ class UnnormalizerProcessor: raise ValueError("Action stats must contain either ('mean','std') or ('min','max')") def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = self._unnormalize_obs(transition[TransitionIndex.OBSERVATION]) - action = self._unnormalize_action(transition[TransitionIndex.ACTION]) - return ( - observation, - action, - transition[TransitionIndex.REWARD], - transition[TransitionIndex.DONE], - transition[TransitionIndex.TRUNCATED], - transition[TransitionIndex.INFO], - transition[TransitionIndex.COMPLEMENTARY_DATA], - ) + observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION)) + action = self._unnormalize_action(transition.get(TransitionKey.ACTION)) + + # Create a new transition with unnormalized values + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = observation + new_transition[TransitionKey.ACTION] = action + return new_transition def get_config(self) -> dict[str, Any]: return { diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index c2c240d32..244bee241 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -21,7 +21,7 @@ import numpy as np import torch from torch import Tensor -from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex +from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey @dataclass @@ -36,7 +36,7 @@ class ImageProcessor: """ def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition[TransitionIndex.OBSERVATION] + observation = 
transition.get(TransitionKey.OBSERVATION) if observation is None: return transition @@ -60,15 +60,9 @@ class ImageProcessor: processed_obs[key] = value # Return new transition with processed observation - return ( - processed_obs, - transition[TransitionIndex.ACTION], - transition[TransitionIndex.REWARD], - transition[TransitionIndex.DONE], - transition[TransitionIndex.TRUNCATED], - transition[TransitionIndex.INFO], - transition[TransitionIndex.COMPLEMENTARY_DATA], - ) + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = processed_obs + return new_transition def _process_single_image(self, img: np.ndarray) -> Tensor: """Process a single image array.""" @@ -124,7 +118,7 @@ class StateProcessor: """ def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition[TransitionIndex.OBSERVATION] + observation = transition.get(TransitionKey.OBSERVATION) if observation is None: return transition @@ -150,15 +144,9 @@ class StateProcessor: del processed_obs["agent_pos"] # Return new transition with processed observation - return ( - processed_obs, - transition[TransitionIndex.ACTION], - transition[TransitionIndex.REWARD], - transition[TransitionIndex.DONE], - transition[TransitionIndex.TRUNCATED], - transition[TransitionIndex.INFO], - transition[TransitionIndex.COMPLEMENTARY_DATA], - ) + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = processed_obs + return new_transition def get_config(self) -> dict[str, Any]: """Return configuration for serialization.""" diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index e6c438781..323a6066c 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -18,39 +18,42 @@ from __future__ import annotations import importlib import json import os -from dataclasses import dataclass, field -from enum import IntEnum -from pathlib import Path -from typing import Any, Protocol from collections.abc 
import Callable, Iterable, Sequence +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Protocol, TypedDict import torch from huggingface_hub import ModelHubMixin, hf_hub_download from safetensors.torch import load_file, save_file -class TransitionIndex(IntEnum): - """Explicit indices for EnvTransition tuple components.""" +class TransitionKey(str, Enum): + """Keys for accessing EnvTransition dictionary components.""" - OBSERVATION = 0 - ACTION = 1 - REWARD = 2 - DONE = 3 - TRUNCATED = 4 - INFO = 5 - COMPLEMENTARY_DATA = 6 + OBSERVATION = "observation" + ACTION = "action" + REWARD = "reward" + DONE = "done" + TRUNCATED = "truncated" + INFO = "info" + COMPLEMENTARY_DATA = "complementary_data" -# (observation, action, reward, done, truncated, info, complementary_data) -EnvTransition = tuple[ - dict[str, Any] | None, # observation - Any | torch.Tensor | None, # action - float | torch.Tensor | None, # reward - bool | torch.Tensor | None, # done - bool | torch.Tensor | None, # truncated - dict[str, Any] | None, # info - dict[str, Any] | None, # complementary_data -] +class EnvTransition(TypedDict, total=False): + """Environment transition data structure. + + All fields are optional (total=False) to allow flexible usage. + """ + + observation: dict[str, Any] | None + action: Any | torch.Tensor | None + reward: float | torch.Tensor | None + done: bool | torch.Tensor | None + truncated: bool | torch.Tensor | None + info: dict[str, Any] | None + complementary_data: dict[str, Any] | None class ProcessorStepRegistry: @@ -165,10 +168,9 @@ class ProcessorStep(Protocol): def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401 """Convert a *batch* dict coming from Learobot replay/dataset code into an - ``EnvTransition`` tuple. + ``EnvTransition`` dictionary. 
- The function is intentionally **strictly positional** – it maps well known - keys to the fixed slot order used inside the pipeline. Missing keys are + The function maps well known keys to the EnvTransition structure. Missing keys are filled with sane defaults (``None`` or ``0.0``/``False``). Keys recognised (case-sensitive): @@ -193,15 +195,16 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq task_key = {"task": batch["task"]} if "task" in batch else {} complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {} - return ( - observation, - batch.get("action"), - batch.get("next.reward", 0.0), - batch.get("next.done", False), - batch.get("next.truncated", False), - batch.get("info", {}), - complementary_data, - ) + transition: EnvTransition = { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: batch.get("action"), + TransitionKey.REWARD: batch.get("next.reward", 0.0), + TransitionKey.DONE: batch.get("next.done", False), + TransitionKey.TRUNCATED: batch.get("next.truncated", False), + TransitionKey.INFO: batch.get("info", {}), + TransitionKey.COMPLEMENTARY_DATA: complementary_data, + } + return transition def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: # noqa: D401 @@ -209,25 +212,16 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: the canonical field names used throughout *LeRobot*. 
""" - ( - observation, - action, - reward, - done, - truncated, - info, - complementary_data, - ) = transition - batch = { - "action": action, - "next.reward": reward, - "next.done": done, - "next.truncated": truncated, - "info": info, + "action": transition.get(TransitionKey.ACTION), + "next.reward": transition.get(TransitionKey.REWARD, 0.0), + "next.done": transition.get(TransitionKey.DONE, False), + "next.truncated": transition.get(TransitionKey.TRUNCATED, False), + "info": transition.get(TransitionKey.INFO, {}), } # Add padding and task data from complementary_data + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) if complementary_data: pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k} batch.update(pad_data) @@ -236,6 +230,7 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: batch["task"] = complementary_data["task"] # Handle observation - flatten dict to observation.* keys if it's a dict + observation = transition.get(TransitionKey.OBSERVATION) if isinstance(observation, dict): batch.update(observation) @@ -293,33 +288,35 @@ class RobotProcessor(ModelHubMixin): def __call__(self, data: EnvTransition | dict[str, Any]): """Process data through all steps. - The method accepts either the classic EnvTransition tuple or a batch dictionary + The method accepts either the classic EnvTransition dict or a batch dictionary (like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied - it is first converted to the internal tuple format using to_transition; after all - steps are executed the tuple is transformed back into a dict with to_batch and the + it is first converted to the internal dict format using to_transition; after all + steps are executed the dict is transformed back into a batch dict with to_batch and the result is returned – thereby preserving the caller's original data type. Args: - data: Either an EnvTransition tuple or a batch dictionary to process. 
+ data: Either an EnvTransition dict or a batch dictionary to process. Returns: - The processed data in the same format as the input (tuple or dict). + The processed data in the same format as the input (EnvTransition or batch dict). Raises: - ValueError: If the transition is not a valid 7-tuple format. + ValueError: If the transition is not a valid EnvTransition format. """ - called_with_batch = isinstance(data, dict) + # Check if data is already an EnvTransition or needs conversion + if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()): + # It's a batch dict, convert it + called_with_batch = True + transition = self.to_transition(data) + else: + # It's already an EnvTransition + called_with_batch = False + transition = data - transition = self.to_transition(data) if called_with_batch else data - - # Basic validation with helpful error message for tuple input - if not isinstance(transition, tuple) or len(transition) != 7: - raise ValueError( - "EnvTransition must be a 7-tuple of (observation, action, reward, done, " - "truncated, info, complementary_data). " - f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}." - ) + # Basic validation + if not isinstance(transition, dict): + raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}") for idx, processor_step in enumerate(self.steps): for hook in self.before_step_hooks: @@ -339,25 +336,28 @@ class RobotProcessor(ModelHubMixin): def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]: """Yield the intermediate results after each processor step. - Like __call__, this method accepts either EnvTransition tuples or batch dictionaries + Like __call__, this method accepts either EnvTransition dicts or batch dictionaries and preserves the input format in the yielded results. 
Args: - data: Either an EnvTransition tuple or a batch dictionary to process. + data: Either an EnvTransition dict or a batch dictionary to process. Yields: The intermediate results after each step, in the same format as the input. """ - called_with_batch = isinstance(data, dict) - transition = self.to_transition(data) if called_with_batch else data + # Check if data is already an EnvTransition or needs conversion + if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()): + # It's a batch dict, convert it + called_with_batch = True + transition = self.to_transition(data) + else: + # It's already an EnvTransition + called_with_batch = False + transition = data - # Basic validation with helpful error message for tuple input - if not isinstance(transition, tuple) or len(transition) != 7: - raise ValueError( - "EnvTransition must be a 7-tuple of (observation, action, reward, done, " - "truncated, info, complementary_data). " - f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}." - ) + # Basic validation + if not isinstance(transition, dict): + raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}") # Yield initial state yield self.to_output(transition) if called_with_batch else transition @@ -684,7 +684,7 @@ class ObservationProcessor: Subclasses should override the `observation` method to implement custom observation processing. This class handles the boilerplate of extracting and reinserting the processed observation - into the transition tuple, eliminating the need to implement the `__call__` method in subclasses. + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. 
Example: ```python @@ -696,7 +696,7 @@ class ObservationProcessor: return observation * self.scale_factor ``` - By inheriting from this class, you avoid writing repetitive code to handle transition tuple + By inheriting from this class, you avoid writing repetitive code to handle transition dict manipulation, focusing only on the specific observation processing logic. """ @@ -712,10 +712,12 @@ class ObservationProcessor: return observation def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition[TransitionIndex.OBSERVATION] - observation = self.observation(observation) - transition = (observation, *transition[TransitionIndex.ACTION :]) - return transition + observation = transition.get(TransitionKey.OBSERVATION) + processed_observation = self.observation(observation) + # Create a new transition dict with the processed observation + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = processed_observation + return new_transition class ActionProcessor: @@ -723,7 +725,7 @@ class ActionProcessor: Subclasses should override the `action` method to implement custom action processing. This class handles the boilerplate of extracting and reinserting the processed action - into the transition tuple, eliminating the need to implement the `__call__` method in subclasses. + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. Example: ```python @@ -736,7 +738,7 @@ class ActionProcessor: return np.clip(action, self.min_val, self.max_val) ``` - By inheriting from this class, you avoid writing repetitive code to handle transition tuple + By inheriting from this class, you avoid writing repetitive code to handle transition dict manipulation, focusing only on the specific action processing logic. 
""" @@ -752,10 +754,12 @@ class ActionProcessor: return action def __call__(self, transition: EnvTransition) -> EnvTransition: - action = transition[TransitionIndex.ACTION] - action = self.action(action) - transition = (transition[TransitionIndex.OBSERVATION], action, *transition[TransitionIndex.REWARD :]) - return transition + action = transition.get(TransitionKey.ACTION) + processed_action = self.action(action) + # Create a new transition dict with the processed action + new_transition = transition.copy() + new_transition[TransitionKey.ACTION] = processed_action + return new_transition class RewardProcessor: @@ -763,7 +767,7 @@ class RewardProcessor: Subclasses should override the `reward` method to implement custom reward processing. This class handles the boilerplate of extracting and reinserting the processed reward - into the transition tuple, eliminating the need to implement the `__call__` method in subclasses. + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. Example: ```python @@ -775,7 +779,7 @@ class RewardProcessor: return reward * self.scale_factor ``` - By inheriting from this class, you avoid writing repetitive code to handle transition tuple + By inheriting from this class, you avoid writing repetitive code to handle transition dict manipulation, focusing only on the specific reward processing logic. 
""" @@ -791,15 +795,12 @@ class RewardProcessor: return reward def __call__(self, transition: EnvTransition) -> EnvTransition: - reward = transition[TransitionIndex.REWARD] - reward = self.reward(reward) - transition = ( - transition[TransitionIndex.OBSERVATION], - transition[TransitionIndex.ACTION], - reward, - *transition[TransitionIndex.DONE :], - ) - return transition + reward = transition.get(TransitionKey.REWARD) + processed_reward = self.reward(reward) + # Create a new transition dict with the processed reward + new_transition = transition.copy() + new_transition[TransitionKey.REWARD] = processed_reward + return new_transition class DoneProcessor: @@ -807,7 +808,7 @@ class DoneProcessor: Subclasses should override the `done` method to implement custom done flag processing. This class handles the boilerplate of extracting and reinserting the processed done flag - into the transition tuple, eliminating the need to implement the `__call__` method in subclasses. + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. Example: ```python @@ -824,7 +825,7 @@ class DoneProcessor: self.steps = 0 ``` - By inheriting from this class, you avoid writing repetitive code to handle transition tuple + By inheriting from this class, you avoid writing repetitive code to handle transition dict manipulation, focusing only on the specific done flag processing logic. 
""" @@ -840,16 +841,12 @@ class DoneProcessor: return done def __call__(self, transition: EnvTransition) -> EnvTransition: - done = transition[TransitionIndex.DONE] - done = self.done(done) - transition = ( - transition[TransitionIndex.OBSERVATION], - transition[TransitionIndex.ACTION], - transition[TransitionIndex.REWARD], - done, - *transition[TransitionIndex.TRUNCATED :], - ) - return transition + done = transition.get(TransitionKey.DONE) + processed_done = self.done(done) + # Create a new transition dict with the processed done flag + new_transition = transition.copy() + new_transition[TransitionKey.DONE] = processed_done + return new_transition class TruncatedProcessor: @@ -857,7 +854,7 @@ class TruncatedProcessor: Subclasses should override the `truncated` method to implement custom truncated flag processing. This class handles the boilerplate of extracting and reinserting the processed truncated flag - into the transition tuple, eliminating the need to implement the `__call__` method in subclasses. + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. Example: ```python @@ -870,7 +867,7 @@ class TruncatedProcessor: return truncated or some_condition > self.threshold ``` - By inheriting from this class, you avoid writing repetitive code to handle transition tuple + By inheriting from this class, you avoid writing repetitive code to handle transition dict manipulation, focusing only on the specific truncated flag processing logic. 
""" @@ -886,17 +883,12 @@ class TruncatedProcessor: return truncated def __call__(self, transition: EnvTransition) -> EnvTransition: - truncated = transition[TransitionIndex.TRUNCATED] - truncated = self.truncated(truncated) - transition = ( - transition[TransitionIndex.OBSERVATION], - transition[TransitionIndex.ACTION], - transition[TransitionIndex.REWARD], - transition[TransitionIndex.DONE], - truncated, - *transition[TransitionIndex.INFO :], - ) - return transition + truncated = transition.get(TransitionKey.TRUNCATED) + processed_truncated = self.truncated(truncated) + # Create a new transition dict with the processed truncated flag + new_transition = transition.copy() + new_transition[TransitionKey.TRUNCATED] = processed_truncated + return new_transition class InfoProcessor: @@ -904,7 +896,7 @@ class InfoProcessor: Subclasses should override the `info` method to implement custom info processing. This class handles the boilerplate of extracting and reinserting the processed info - into the transition tuple, eliminating the need to implement the `__call__` method in subclasses. + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. Example: ```python @@ -922,7 +914,7 @@ class InfoProcessor: self.step_count = 0 ``` - By inheriting from this class, you avoid writing repetitive code to handle transition tuple + By inheriting from this class, you avoid writing repetitive code to handle transition dict manipulation, focusing only on the specific info dictionary processing logic. 
""" @@ -938,18 +930,12 @@ class InfoProcessor: return info def __call__(self, transition: EnvTransition) -> EnvTransition: - info = transition[TransitionIndex.INFO] - info = self.info(info) - transition = ( - transition[TransitionIndex.OBSERVATION], - transition[TransitionIndex.ACTION], - transition[TransitionIndex.REWARD], - transition[TransitionIndex.DONE], - transition[TransitionIndex.TRUNCATED], - info, - *transition[TransitionIndex.COMPLEMENTARY_DATA :], - ) - return transition + info = transition.get(TransitionKey.INFO) + processed_info = self.info(info) + # Create a new transition dict with the processed info + new_transition = transition.copy() + new_transition[TransitionKey.INFO] = processed_info + return new_transition class ComplementaryDataProcessor: @@ -957,7 +943,7 @@ class ComplementaryDataProcessor: Subclasses should override the `complementary_data` method to implement custom complementary data processing. This class handles the boilerplate of extracting and reinserting the processed complementary data - into the transition tuple, eliminating the need to implement the `__call__` method in subclasses. + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. 
""" def complementary_data(self, complementary_data): @@ -972,18 +958,12 @@ class ComplementaryDataProcessor: return complementary_data def __call__(self, transition: EnvTransition) -> EnvTransition: - complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA] - complementary_data = self.complementary_data(complementary_data) - transition = ( - transition[TransitionIndex.OBSERVATION], - transition[TransitionIndex.ACTION], - transition[TransitionIndex.REWARD], - transition[TransitionIndex.DONE], - transition[TransitionIndex.TRUNCATED], - transition[TransitionIndex.INFO], - complementary_data, - ) - return transition + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + processed_complementary_data = self.complementary_data(complementary_data) + # Create a new transition dict with the processed complementary data + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data + return new_transition class IdentityProcessor: diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py index 0eb3d0b98..08855e237 100644 --- a/src/lerobot/processor/rename_processor.py +++ b/src/lerobot/processor/rename_processor.py @@ -18,7 +18,7 @@ from typing import Any import torch -from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex +from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey @dataclass @@ -29,7 +29,7 @@ class RenameProcessor: rename_map: dict[str, str] = field(default_factory=dict) def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition[TransitionIndex.OBSERVATION] + observation = transition.get(TransitionKey.OBSERVATION) if observation is None: return transition @@ -39,15 +39,11 @@ class RenameProcessor: processed_obs[self.rename_map[key]] = value else: processed_obs[key] = value - return ( - processed_obs, - 
transition[TransitionIndex.ACTION], - transition[TransitionIndex.REWARD], - transition[TransitionIndex.DONE], - transition[TransitionIndex.TRUNCATED], - transition[TransitionIndex.INFO], - transition[TransitionIndex.COMPLEMENTARY_DATA], - ) + + # Create a new transition with the renamed observation + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = processed_obs + return new_transition def get_config(self) -> dict[str, Any]: return {"rename_map": self.rename_map} diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index c8e1a80cc..b2e357645 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -72,7 +72,7 @@ from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types from lerobot.policies.factory import make_policy from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters -from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor +from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor from lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( @@ -160,7 +160,7 @@ def rollout( # Numpy array to tensor and changing dictionary keys to LeRobot policy format. 
transition = (observation, None, None, None, None, None, None) processed_transition = obs_processor(transition) - observation = processed_transition[TransitionIndex.OBSERVATION] + observation = processed_transition[TransitionKey.OBSERVATION] if return_observations: all_observations.append(deepcopy(observation)) @@ -211,7 +211,7 @@ def rollout( if return_observations: transition = (observation, None, None, None, None, None, None) processed_transition = obs_processor(transition) - observation = processed_transition[TransitionIndex.OBSERVATION] + observation = processed_transition[TransitionKey.OBSERVATION] all_observations.append(deepcopy(observation)) # Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors. diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 15ce1f933..e3f50c74c 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -22,7 +22,7 @@ from gymnasium.utils.env_checker import check_env import lerobot from lerobot.envs.factory import make_env, make_env_config -from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor +from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor from tests.utils import require_env OBS_TYPES = ["state", "pixels", "pixels_agent_pos"] @@ -53,7 +53,7 @@ def test_factory(env_name): obs_processor = RobotProcessor([VanillaObservationProcessor()]) transition = (obs, None, None, None, None, None, None) processed_transition = obs_processor(transition) - obs = processed_transition[TransitionIndex.OBSERVATION] + obs = processed_transition[TransitionKey.OBSERVATION] # test image keys are float32 in range [0,1] for key in obs: diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 44751a829..0179ce331 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -39,7 +39,7 @@ from lerobot.policies.factory import ( ) from lerobot.policies.normalize import 
Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.processor import RobotProcessor, TransitionIndex, VanillaObservationProcessor +from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor from lerobot.utils.random_utils import seeded_context from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel @@ -188,7 +188,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): obs_processor = RobotProcessor([VanillaObservationProcessor()]) transition = (observation, None, None, None, None, None, None) processed_transition = obs_processor(transition) - observation = processed_transition[TransitionIndex.OBSERVATION] + observation = processed_transition[TransitionKey.OBSERVATION] # send observation to device/gpu observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation} diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index de2ca0e7d..63894025d 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -2,7 +2,7 @@ import torch from lerobot.processor.pipeline import ( RobotProcessor, - TransitionIndex, + TransitionKey, _default_batch_to_transition, _default_transition_to_batch, ) @@ -63,27 +63,27 @@ def test_batch_to_transition_observation_grouping(): transition = _default_batch_to_transition(batch) # Check observation is a dict with all observation.* keys - assert isinstance(transition[TransitionIndex.OBSERVATION], dict) - assert "observation.image.top" in transition[TransitionIndex.OBSERVATION] - assert "observation.image.left" in transition[TransitionIndex.OBSERVATION] - assert "observation.state" in transition[TransitionIndex.OBSERVATION] + assert isinstance(transition[TransitionKey.OBSERVATION], dict) + assert "observation.image.top" in 
transition[TransitionKey.OBSERVATION] + assert "observation.image.left" in transition[TransitionKey.OBSERVATION] + assert "observation.state" in transition[TransitionKey.OBSERVATION] # Check values are preserved assert torch.allclose( - transition[TransitionIndex.OBSERVATION]["observation.image.top"], batch["observation.image.top"] + transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"] ) assert torch.allclose( - transition[TransitionIndex.OBSERVATION]["observation.image.left"], batch["observation.image.left"] + transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"] ) - assert transition[TransitionIndex.OBSERVATION]["observation.state"] == [1, 2, 3, 4] + assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4] # Check other fields - assert transition[TransitionIndex.ACTION] == "action_data" - assert transition[TransitionIndex.REWARD] == 1.5 - assert transition[TransitionIndex.DONE] - assert not transition[TransitionIndex.TRUNCATED] - assert transition[TransitionIndex.INFO] == {"episode": 42} - assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {} + assert transition[TransitionKey.ACTION] == "action_data" + assert transition[TransitionKey.REWARD] == 1.5 + assert transition[TransitionKey.DONE] + assert not transition[TransitionKey.TRUNCATED] + assert transition[TransitionKey.INFO] == {"episode": 42} + assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} def test_transition_to_batch_observation_flattening(): @@ -94,15 +94,15 @@ def test_transition_to_batch_observation_flattening(): "observation.state": [1, 2, 3, 4], } - transition = ( - observation_dict, # observation - "action_data", # action - 1.5, # reward - True, # done - False, # truncated - {"episode": 42}, # info - {}, # complementary_data - ) + transition = { + TransitionKey.OBSERVATION: observation_dict, + TransitionKey.ACTION: "action_data", + TransitionKey.REWARD: 1.5, + TransitionKey.DONE: 
True, + TransitionKey.TRUNCATED: False, + TransitionKey.INFO: {"episode": 42}, + TransitionKey.COMPLEMENTARY_DATA: {}, + } batch = _default_transition_to_batch(transition) @@ -137,14 +137,14 @@ def test_no_observation_keys(): transition = _default_batch_to_transition(batch) # Observation should be None when no observation.* keys - assert transition[TransitionIndex.OBSERVATION] is None + assert transition[TransitionKey.OBSERVATION] is None # Check other fields - assert transition[TransitionIndex.ACTION] == "action_data" - assert transition[TransitionIndex.REWARD] == 2.0 - assert not transition[TransitionIndex.DONE] - assert transition[TransitionIndex.TRUNCATED] - assert transition[TransitionIndex.INFO] == {"test": "no_obs"} + assert transition[TransitionKey.ACTION] == "action_data" + assert transition[TransitionKey.REWARD] == 2.0 + assert not transition[TransitionKey.DONE] + assert transition[TransitionKey.TRUNCATED] + assert transition[TransitionKey.INFO] == {"test": "no_obs"} # Round trip should work reconstructed_batch = _default_transition_to_batch(transition) @@ -162,15 +162,15 @@ def test_minimal_batch(): transition = _default_batch_to_transition(batch) # Check observation - assert transition[TransitionIndex.OBSERVATION] == {"observation.state": "minimal_state"} - assert transition[TransitionIndex.ACTION] == "minimal_action" + assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"} + assert transition[TransitionKey.ACTION] == "minimal_action" # Check defaults - assert transition[TransitionIndex.REWARD] == 0.0 - assert not transition[TransitionIndex.DONE] - assert not transition[TransitionIndex.TRUNCATED] - assert transition[TransitionIndex.INFO] == {} - assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {} + assert transition[TransitionKey.REWARD] == 0.0 + assert not transition[TransitionKey.DONE] + assert not transition[TransitionKey.TRUNCATED] + assert transition[TransitionKey.INFO] == {} + assert 
transition[TransitionKey.COMPLEMENTARY_DATA] == {} # Round trip reconstructed_batch = _default_transition_to_batch(transition) @@ -189,13 +189,13 @@ def test_empty_batch(): transition = _default_batch_to_transition(batch) # All fields should have defaults - assert transition[TransitionIndex.OBSERVATION] is None - assert transition[TransitionIndex.ACTION] is None - assert transition[TransitionIndex.REWARD] == 0.0 - assert not transition[TransitionIndex.DONE] - assert not transition[TransitionIndex.TRUNCATED] - assert transition[TransitionIndex.INFO] == {} - assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {} + assert transition[TransitionKey.OBSERVATION] is None + assert transition[TransitionKey.ACTION] is None + assert transition[TransitionKey.REWARD] == 0.0 + assert not transition[TransitionKey.DONE] + assert not transition[TransitionKey.TRUNCATED] + assert transition[TransitionKey.INFO] == {} + assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} # Round trip reconstructed_batch = _default_transition_to_batch(transition) @@ -256,33 +256,27 @@ def test_custom_converter(): # Custom converter that modifies the reward tr = _default_batch_to_transition(batch) # Double the reward - reward = tr[TransitionIndex.REWARD] * 2 if tr[TransitionIndex.REWARD] is not None else 0.0 - return ( - tr[TransitionIndex.OBSERVATION], - tr[TransitionIndex.ACTION], - reward, - tr[TransitionIndex.DONE], - tr[TransitionIndex.TRUNCATED], - tr[TransitionIndex.INFO], - tr[TransitionIndex.COMPLEMENTARY_DATA], - ) + reward = tr.get(TransitionKey.REWARD, 0.0) + new_tr = tr.copy() + new_tr[TransitionKey.REWARD] = reward * 2 if reward is not None else 0.0 + return new_tr def to_batch(tr): - # Custom converter that adds a custom field batch = _default_transition_to_batch(tr) - batch["custom_field"] = "custom_value" return batch - proc = RobotProcessor([], to_transition=to_tr, to_output=to_batch) - batch = _dummy_batch() - out = proc(batch) + processor = RobotProcessor(steps=[], 
to_transition=to_tr, to_output=to_batch) - # Check that custom modifications were applied - assert out["next.reward"] == batch["next.reward"] * 2 - assert out["custom_field"] == "custom_value" + batch = { + "observation.state": torch.randn(1, 4), + "action": torch.randn(1, 2), + "next.reward": 1.0, + "next.done": False, + } - # Check that observation.* keys are still preserved - original_obs_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} - output_obs_keys = {k: v for k, v in out.items() if k.startswith("observation.")} + result = processor(batch) - assert set(original_obs_keys.keys()) == set(output_obs_keys.keys()) + # Check the reward was doubled by our custom converter + assert result["next.reward"] == 2.0 + assert torch.allclose(result["observation.state"], batch["observation.state"]) + assert torch.allclose(result["action"], batch["action"]) diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 3aabbe532..26aea56c7 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -1,3 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from unittest.mock import Mock import numpy as np @@ -10,7 +25,22 @@ from lerobot.processor.normalize_processor import ( UnnormalizerProcessor, _convert_stats_to_tensors, ) -from lerobot.processor.pipeline import RobotProcessor, TransitionIndex +from lerobot.processor.pipeline import RobotProcessor, TransitionKey + + +def create_transition( + observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None +): + """Helper to create an EnvTransition dictionary.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: complementary_data, + } def test_numpy_conversion(): @@ -120,10 +150,10 @@ def test_mean_std_normalization(observation_normalizer): "observation.image": torch.tensor([0.7, 0.5, 0.3]), "observation.state": torch.tensor([0.5, 0.0]), } - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) normalized_transition = observation_normalizer(transition) - normalized_obs = normalized_transition[TransitionIndex.OBSERVATION] + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] # Check mean/std normalization expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 @@ -134,10 +164,10 @@ def test_min_max_normalization(observation_normalizer): observation = { "observation.state": torch.tensor([0.5, 0.0]), } - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) normalized_transition = observation_normalizer(transition) - normalized_obs = normalized_transition[TransitionIndex.OBSERVATION] + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] # Check min/max normalization to [-1, 1] # For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0 @@ -157,10 +187,10 @@ def 
test_selective_normalization(observation_stats): "observation.image": torch.tensor([0.7, 0.5, 0.3]), "observation.state": torch.tensor([0.5, 0.0]), } - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) normalized_transition = normalizer(transition) - normalized_obs = normalized_transition[TransitionIndex.OBSERVATION] + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] # Only image should be normalized assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2) @@ -176,10 +206,10 @@ def test_device_compatibility(observation_stats): observation = { "observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(), } - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) normalized_transition = normalizer(transition) - normalized_obs = normalized_transition[TransitionIndex.OBSERVATION] + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] assert normalized_obs["observation.image"].device.type == "cuda" @@ -220,10 +250,10 @@ def test_state_dict_save_load(observation_normalizer): # Test that it works the same observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) - result1 = observation_normalizer(transition)[0] - result2 = new_normalizer(transition)[0] + result1 = observation_normalizer(transition)[TransitionKey.OBSERVATION] + result2 = new_normalizer(transition)[TransitionKey.OBSERVATION] assert torch.allclose(result1["observation.image"], result2["observation.image"]) @@ -271,10 +301,10 @@ def test_mean_std_unnormalization(action_stats_mean_std): ) normalized_action = torch.tensor([1.0, -0.5, 2.0]) - transition = (None, normalized_action, None, None, None, None, None) + transition = create_transition(action=normalized_action) 
unnormalized_transition = unnormalizer(transition) - unnormalized_action = unnormalized_transition[TransitionIndex.ACTION] + unnormalized_action = unnormalized_transition[TransitionKey.ACTION] # action * std + mean expected = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0, 2.0 * 0.5 + 0.0]) @@ -290,10 +320,10 @@ def test_min_max_unnormalization(action_stats_min_max): # Actions in [-1, 1] normalized_action = torch.tensor([0.0, -1.0, 1.0]) - transition = (None, normalized_action, None, None, None, None, None) + transition = create_transition(action=normalized_action) unnormalized_transition = unnormalizer(transition) - unnormalized_action = unnormalized_transition[TransitionIndex.ACTION] + unnormalized_action = unnormalized_transition[TransitionKey.ACTION] # Map from [-1, 1] to [min, max] # (action + 1) / 2 * (max - min) + min @@ -315,10 +345,10 @@ def test_numpy_action_input(action_stats_mean_std): ) normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32) - transition = (None, normalized_action, None, None, None, None, None) + transition = create_transition(action=normalized_action) unnormalized_transition = unnormalizer(transition) - unnormalized_action = unnormalized_transition[TransitionIndex.ACTION] + unnormalized_action = unnormalized_transition[TransitionKey.ACTION] assert isinstance(unnormalized_action, torch.Tensor) expected = torch.tensor([1.0, -1.0, 1.0]) @@ -332,7 +362,7 @@ def test_none_action(action_stats_mean_std): features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} ) - transition = (None, None, None, None, None, None, None) + transition = create_transition() result = unnormalizer(transition) # Should return transition unchanged @@ -396,23 +426,31 @@ def test_combined_normalization(normalizer_processor): "observation.state": torch.tensor([0.5, 0.0]), } action = torch.tensor([1.0, -0.5]) - transition = (observation, action, 1.0, False, False, {}, {}) + transition = create_transition( + observation=observation, + 
action=action, + reward=1.0, + done=False, + truncated=False, + info={}, + complementary_data={}, + ) processed_transition = normalizer_processor(transition) # Check normalized observations - processed_obs = processed_transition[TransitionIndex.OBSERVATION] + processed_obs = processed_transition[TransitionKey.OBSERVATION] expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 assert torch.allclose(processed_obs["observation.image"], expected_image) # Check normalized action - processed_action = processed_transition[TransitionIndex.ACTION] + processed_action = processed_transition[TransitionKey.ACTION] expected_action = torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0]) assert torch.allclose(processed_action, expected_action) # Check other fields remain unchanged - assert processed_transition[TransitionIndex.REWARD] == 1.0 - assert not processed_transition[TransitionIndex.DONE] + assert processed_transition[TransitionKey.REWARD] == 1.0 + assert not processed_transition[TransitionKey.DONE] def test_processor_from_lerobot_dataset(full_stats): @@ -466,13 +504,21 @@ def test_integration_with_robot_processor(normalizer_processor): "observation.state": torch.tensor([0.5, 0.0]), } action = torch.tensor([1.0, -0.5]) - transition = (observation, action, 1.0, False, False, {}, {}) + transition = create_transition( + observation=observation, + action=action, + reward=1.0, + done=False, + truncated=False, + info={}, + complementary_data={}, + ) processed_transition = robot_processor(transition) # Verify the processing worked - assert isinstance(processed_transition[TransitionIndex.OBSERVATION], dict) - assert isinstance(processed_transition[TransitionIndex.ACTION], torch.Tensor) + assert isinstance(processed_transition[TransitionKey.OBSERVATION], dict) + assert isinstance(processed_transition[TransitionKey.ACTION], torch.Tensor) # Edge case tests @@ -482,7 +528,7 @@ def test_empty_observation(): norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} normalizer = 
NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) - transition = (None, None, None, None, None, None, None) + transition = create_transition() result = normalizer(transition) assert result == transition @@ -493,11 +539,13 @@ def test_empty_stats(): norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={}) observation = {"observation.image": torch.tensor([0.5])} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = normalizer(transition) # Should return observation unchanged since no stats are available - assert torch.allclose(result[0]["observation.image"], observation["observation.image"]) + assert torch.allclose( + result[TransitionKey.OBSERVATION]["observation.image"], observation["observation.image"] + ) def test_partial_stats(): @@ -507,9 +555,9 @@ def test_partial_stats(): norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) observation = {"observation.image": torch.tensor([0.7])} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) - processed = normalizer(transition)[TransitionIndex.OBSERVATION] + processed = normalizer(transition)[TransitionKey.OBSERVATION] assert torch.allclose(processed["observation.image"], observation["observation.image"]) @@ -551,14 +599,25 @@ def test_serialization_roundtrip(full_stats): "observation.state": torch.tensor([0.5, 0.0]), } action = torch.tensor([1.0, -0.5]) - transition = (observation, action, 1.0, False, False, {}, {}) + transition = create_transition( + observation=observation, + action=action, + reward=1.0, + done=False, + truncated=False, + info={}, + complementary_data={}, + ) result1 = original_processor(transition) result2 = new_processor(transition) # Compare results - assert 
torch.allclose(result1[0]["observation.image"], result2[0]["observation.image"]) - assert torch.allclose(result1[1], result2[1]) + assert torch.allclose( + result1[TransitionKey.OBSERVATION]["observation.image"], + result2[TransitionKey.OBSERVATION]["observation.image"], + ) + assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) # Verify features and norm_map are correctly reconstructed assert new_processor.features.keys() == original_processor.features.keys() diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index 5e06fd7fa..5026a9177 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -23,6 +23,22 @@ from lerobot.processor import ( StateProcessor, VanillaObservationProcessor, ) +from lerobot.processor.pipeline import TransitionKey + + +def create_transition( + observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None +): + """Helper to create an EnvTransition dictionary.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: complementary_data, + } def test_process_single_image(): @@ -33,10 +49,10 @@ def test_process_single_image(): image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) observation = {"pixels": image} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[0] + processed_obs = result[TransitionKey.OBSERVATION] # Check that the image was processed correctly assert "observation.image" in processed_obs @@ -60,10 +76,10 @@ def test_process_image_dict(): image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8) 
observation = {"pixels": {"camera1": image1, "camera2": image2}} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[0] + processed_obs = result[TransitionKey.OBSERVATION] # Check that both images were processed assert "observation.images.camera1" in processed_obs @@ -82,10 +98,10 @@ def test_process_batched_image(): image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8) observation = {"pixels": image} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[0] + processed_obs = result[TransitionKey.OBSERVATION] # Check that batch dimension is preserved assert processed_obs["observation.image"].shape == (2, 3, 64, 64) @@ -98,7 +114,7 @@ def test_invalid_image_format(): # Test wrong channel order (channels first) image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8) observation = {"pixels": image} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) with pytest.raises(ValueError, match="Expected channel-last images"): processor(transition) @@ -111,7 +127,7 @@ def test_invalid_image_dtype(): # Test wrong dtype image = np.random.rand(64, 64, 3).astype(np.float32) observation = {"pixels": image} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) with pytest.raises(ValueError, match="Expected torch.uint8 images"): processor(transition) @@ -122,10 +138,10 @@ def test_no_pixels_in_observation(): processor = ImageProcessor() observation = {"other_data": np.array([1, 2, 3])} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[0] + 
processed_obs = result[TransitionKey.OBSERVATION] # Should preserve other data unchanged assert "other_data" in processed_obs @@ -136,7 +152,7 @@ def test_none_observation(): """Test processor with None observation.""" processor = ImageProcessor() - transition = (None, None, None, None, None, None, None) + transition = create_transition() result = processor(transition) assert result == transition @@ -167,10 +183,10 @@ def test_process_environment_state(): env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) observation = {"environment_state": env_state} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[0] + processed_obs = result[TransitionKey.OBSERVATION] # Check that environment_state was renamed and processed assert "observation.environment_state" in processed_obs @@ -188,10 +204,10 @@ def test_process_agent_pos(): agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) observation = {"agent_pos": agent_pos} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[0] + processed_obs = result[TransitionKey.OBSERVATION] # Check that agent_pos was renamed and processed assert "observation.state" in processed_obs @@ -211,10 +227,10 @@ def test_process_batched_states(): agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32) observation = {"environment_state": env_state, "agent_pos": agent_pos} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[0] + processed_obs = result[TransitionKey.OBSERVATION] # Check that batch dimensions are preserved assert processed_obs["observation.environment_state"].shape == (2, 2) @@ -229,10 +245,10 @@ def test_process_both_states(): agent_pos = 
np.array([0.5, -0.5], dtype=np.float32) observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[0] + processed_obs = result[TransitionKey.OBSERVATION] # Check that both states were processed assert "observation.environment_state" in processed_obs @@ -251,10 +267,10 @@ def test_no_states_in_observation(): processor = StateProcessor() observation = {"other_data": np.array([1, 2, 3])} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[0] + processed_obs = result[TransitionKey.OBSERVATION] # Should preserve data unchanged np.testing.assert_array_equal(processed_obs, observation) @@ -275,10 +291,10 @@ def test_complete_observation_processing(): "agent_pos": agent_pos, "other_data": "preserve_me", } - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[0] + processed_obs = result[TransitionKey.OBSERVATION] # Check that image was processed assert "observation.image" in processed_obs @@ -303,10 +319,10 @@ def test_image_only_processing(): image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) observation = {"pixels": image} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[0] + processed_obs = result[TransitionKey.OBSERVATION] assert "observation.image" in processed_obs assert len(processed_obs) == 1 @@ -318,10 +334,10 @@ def test_state_only_processing(): agent_pos = np.array([1.0, 2.0], dtype=np.float32) observation = {"agent_pos": agent_pos} - transition = 
(observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[0] + processed_obs = result[TransitionKey.OBSERVATION] assert "observation.state" in processed_obs assert "agent_pos" not in processed_obs @@ -332,10 +348,10 @@ def test_empty_observation(): processor = VanillaObservationProcessor() observation = {} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[0] + processed_obs = result[TransitionKey.OBSERVATION] assert processed_obs == {} @@ -369,8 +385,8 @@ def test_equivalent_to_original_function(): original_result = preprocess_observation(observation) # Process with new processor - transition = (observation, None, None, None, None, None, None) - processor_result = processor(transition)[0] + transition = create_transition(observation=observation) + processor_result = processor(transition)[TransitionKey.OBSERVATION] # Compare results assert set(original_result.keys()) == set(processor_result.keys()) @@ -396,8 +412,8 @@ def test_equivalent_with_image_dict(): original_result = preprocess_observation(observation) # Process with new processor - transition = (observation, None, None, None, None, None, None) - processor_result = processor(transition)[0] + transition = create_transition(observation=observation) + processor_result = processor(transition)[TransitionKey.OBSERVATION] # Compare results assert set(original_result.keys()) == set(processor_result.keys()) diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 801e3270a..a21e229dd 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -18,7 +18,7 @@ import json import tempfile from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict +from typing import Any import numpy as np import pytest @@ 
-26,6 +26,22 @@ import torch import torch.nn as nn from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor +from lerobot.processor.pipeline import TransitionKey + + +def create_transition( + observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None +): + """Helper to create an EnvTransition dictionary.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info if info is not None else {}, + TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {}, + } @dataclass @@ -45,14 +61,16 @@ class MockStep: def __call__(self, transition: EnvTransition) -> EnvTransition: """Add a counter to the complementary_data.""" - obs, action, reward, done, truncated, info, comp_data = transition - + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) comp_data = {} if comp_data is None else dict(comp_data) # Make a copy comp_data[f"{self.name}_counter"] = self.counter self.counter += 1 - return (obs, action, reward, done, truncated, info, comp_data) + # Create a new transition with updated complementary_data + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition def get_config(self) -> dict[str, Any]: # Return all JSON-serializable attributes that should be persisted @@ -79,12 +97,14 @@ class MockStepWithoutOptionalMethods: def __call__(self, transition: EnvTransition) -> EnvTransition: """Multiply reward by multiplier.""" - obs, action, reward, done, truncated, info, comp_data = transition + reward = transition.get(TransitionKey.REWARD) if reward is not None: - reward = reward * self.multiplier + new_transition = transition.copy() + new_transition[TransitionKey.REWARD] = reward * self.multiplier + return new_transition - return (obs, action, 
reward, done, truncated, info, comp_data) + return transition @dataclass @@ -105,7 +125,7 @@ class MockStepWithTensorState: def __call__(self, transition: EnvTransition) -> EnvTransition: """Update running statistics.""" - obs, action, reward, done, truncated, info, comp_data = transition + reward = transition.get(TransitionKey.REWARD) if reward is not None: # Update running mean @@ -143,7 +163,7 @@ def test_empty_pipeline(): """Test pipeline with no steps.""" pipeline = RobotProcessor() - transition = (None, None, 0.0, False, False, {}, {}) + transition = create_transition() result = pipeline(transition) assert result == transition @@ -155,15 +175,15 @@ def test_single_step_pipeline(): step = MockStep("test_step") pipeline = RobotProcessor([step]) - transition = (None, None, 0.0, False, False, {}, {}) + transition = create_transition() result = pipeline(transition) assert len(pipeline) == 1 - assert result[6]["test_step_counter"] == 0 # complementary_data + assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0 # Call again to test counter increment result = pipeline(transition) - assert result[6]["test_step_counter"] == 1 + assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 1 def test_multiple_steps_pipeline(): @@ -172,46 +192,46 @@ def test_multiple_steps_pipeline(): step2 = MockStep("step2") pipeline = RobotProcessor([step1, step2]) - transition = (None, None, 0.0, False, False, {}, {}) + transition = create_transition() result = pipeline(transition) assert len(pipeline) == 2 - assert result[6]["step1_counter"] == 0 - assert result[6]["step2_counter"] == 0 + assert result[TransitionKey.COMPLEMENTARY_DATA]["step1_counter"] == 0 + assert result[TransitionKey.COMPLEMENTARY_DATA]["step2_counter"] == 0 def test_invalid_transition_format(): """Test pipeline with invalid transition format.""" pipeline = RobotProcessor([MockStep()]) - # Test with wrong number of elements - with pytest.raises(ValueError, match="EnvTransition must 
be a 7-tuple"): - pipeline((None, None, 0.0)) # Only 3 elements + # Test with wrong type (tuple instead of dict) + with pytest.raises(ValueError, match="EnvTransition must be a dictionary"): + pipeline((None, None, 0.0, False, False, {}, {})) # Tuple instead of dict - # Test with wrong type - with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"): - pipeline("not a tuple") + # Test with wrong type (string) + with pytest.raises(ValueError, match="EnvTransition must be a dictionary"): + pipeline("not a dict") def test_step_through(): - """Test step_through method with tuple input.""" + """Test step_through method with dict input.""" step1 = MockStep("step1") step2 = MockStep("step2") pipeline = RobotProcessor([step1, step2]) - transition = (None, None, 0.0, False, False, {}, {}) + transition = create_transition() results = list(pipeline.step_through(transition)) assert len(results) == 3 # Original + 2 steps assert results[0] == transition # Original - assert "step1_counter" in results[1][6] # After step1 - assert "step2_counter" in results[2][6] # After step2 + assert "step1_counter" in results[1][TransitionKey.COMPLEMENTARY_DATA] # After step1 + assert "step2_counter" in results[2][TransitionKey.COMPLEMENTARY_DATA] # After step2 - # Ensure all results are tuples (same format as input) + # Ensure all results are dicts (same format as input) for result in results: - assert isinstance(result, tuple) - assert len(result) == 7 + assert isinstance(result, dict) + assert all(isinstance(k, TransitionKey) for k in result.keys()) def test_step_through_with_dict(): @@ -279,7 +299,7 @@ def test_hooks(): pipeline.register_before_step_hook(before_hook) pipeline.register_after_step_hook(after_hook) - transition = (None, None, 0.0, False, False, {}, {}) + transition = create_transition() pipeline(transition) assert before_calls == [0] @@ -292,15 +312,16 @@ def test_hook_modification(): pipeline = RobotProcessor([step]) def modify_reward_hook(idx: int, transition: 
EnvTransition): - obs, action, reward, done, truncated, info, comp_data = transition - return (obs, action, 42.0, done, truncated, info, comp_data) + new_transition = transition.copy() + new_transition[TransitionKey.REWARD] = 42.0 + return new_transition pipeline.register_before_step_hook(modify_reward_hook) - transition = (None, None, 0.0, False, False, {}, {}) + transition = create_transition() result = pipeline(transition) - assert result[2] == 42.0 # reward modified by hook + assert result[TransitionKey.REWARD] == 42.0 # reward modified by hook def test_reset(): @@ -316,7 +337,7 @@ def test_reset(): pipeline.register_reset_hook(reset_hook) # Make some calls to increment counter - transition = (None, None, 0.0, False, False, {}, {}) + transition = create_transition() pipeline(transition) pipeline(transition) @@ -335,7 +356,7 @@ def test_profile_steps(): step2 = MockStep("step2") pipeline = RobotProcessor([step1, step2]) - transition = (None, None, 0.0, False, False, {}, {}) + transition = create_transition() profile_results = pipeline.profile_steps(transition, num_runs=10) @@ -397,10 +418,10 @@ def test_step_without_optional_methods(): step = MockStepWithoutOptionalMethods(multiplier=3.0) pipeline = RobotProcessor([step]) - transition = (None, None, 2.0, False, False, {}, {}) + transition = create_transition(reward=2.0) result = pipeline(transition) - assert result[2] == 6.0 # 2.0 * 3.0 + assert result[TransitionKey.REWARD] == 6.0 # 2.0 * 3.0 # Reset should work even if step doesn't implement reset pipeline.reset() @@ -419,7 +440,7 @@ def test_mixed_json_and_tensor_state(): # Process some transitions with rewards for i in range(10): - transition = (None, None, float(i), False, False, {}, {}) + transition = create_transition(reward=float(i)) pipeline(transition) # Check state @@ -466,7 +487,7 @@ class MockModuleStep(nn.Module): def __call__(self, transition: EnvTransition) -> EnvTransition: """Process transition and update running mean.""" - obs, action, reward, 
done, truncated, info, comp_data = transition + obs = transition.get(TransitionKey.OBSERVATION) if obs is not None and isinstance(obs, torch.Tensor): # Process observation through linear layer @@ -509,7 +530,7 @@ def test_to_device_with_state_dict(): # Process some transitions to populate state for i in range(10): - transition = (None, None, float(i), False, False, {}, {}) + transition = create_transition(reward=float(i)) pipeline(transition) # Check initial device (should be CPU) @@ -551,7 +572,7 @@ def test_to_device_with_module(): # Process some data obs = torch.randn(2, 5) - transition = (obs, None, 1.0, False, False, {}, {}) + transition = create_transition(observation=obs, reward=1.0) pipeline(transition) # Check initial device @@ -575,7 +596,7 @@ def test_to_device_with_module(): # Verify the module still works after transfer obs_cuda = torch.randn(2, 5, device="cuda:0") - transition = (obs_cuda, None, 1.0, False, False, {}, {}) + transition = create_transition(observation=obs_cuda, reward=1.0) pipeline(transition) # Should not raise an error @@ -589,7 +610,7 @@ def test_to_device_mixed_steps(): # Process some data for i in range(5): - transition = (torch.randn(2, 10), None, float(i), False, False, {}, {}) + transition = create_transition(observation=torch.randn(2, 10), reward=float(i)) pipeline(transition) # Check initial state @@ -630,7 +651,7 @@ def test_to_device_preserves_functionality(): # Process initial data rewards = [1.0, 2.0, 3.0] for r in rewards: - transition = (None, None, r, False, False, {}, {}) + transition = create_transition(reward=r) pipeline(transition) # Check state before transfer @@ -645,7 +666,7 @@ def test_to_device_preserves_functionality(): assert step.running_count == initial_count # Process more data to ensure functionality - transition = (None, None, 4.0, False, False, {}, {}) + transition = create_transition(reward=4.0) _ = pipeline(transition) assert step.running_count == 4 @@ -700,7 +721,8 @@ class 
MockNonModuleStepWithState: def __call__(self, transition: EnvTransition) -> EnvTransition: """Process transition using tensor operations.""" - obs, action, reward, done, truncated, info, comp_data = transition + obs = transition.get(TransitionKey.OBSERVATION) + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) if obs is not None and isinstance(obs, torch.Tensor) and obs.numel() >= self.feature_dim: # Perform some tensor operations @@ -718,7 +740,12 @@ class MockNonModuleStepWithState: comp_data[f"{self.name}_mean_output"] = output.mean().item() comp_data[f"{self.name}_steps"] = self.step_count.item() - return (obs, action, reward, done, truncated, info, comp_data) + # Return updated transition + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition + + return transition def get_config(self) -> dict[str, Any]: return { @@ -763,9 +790,9 @@ def test_to_device_non_module_class(): # Process some data to populate state for i in range(3): obs = torch.randn(2, 5) - transition = (obs, None, float(i), False, False, {}, {}) + transition = create_transition(observation=obs, reward=float(i)) result = pipeline(transition) - comp_data = result[6] + comp_data = result[TransitionKey.COMPLEMENTARY_DATA] assert f"{non_module_step.name}_steps" in comp_data # Verify all tensors are on CPU initially @@ -811,9 +838,9 @@ def test_to_device_non_module_class(): # Test that step still works on GPU obs_gpu = torch.randn(2, 5, device="cuda") - transition = (obs_gpu, None, 1.0, False, False, {}, {}) + transition = create_transition(observation=obs_gpu, reward=1.0) result = pipeline(transition) - comp_data = result[6] + comp_data = result[TransitionKey.COMPLEMENTARY_DATA] # Verify processing worked assert comp_data[f"{non_module_step.name}_steps"] == 4 @@ -835,7 +862,7 @@ def test_to_device_module_vs_non_module(): # Process some data obs = torch.randn(2, 5) - transition = (obs, None, 1.0, False, False, {}, {}) + 
transition = create_transition(observation=obs, reward=1.0) _ = pipeline(transition) # Check initial devices @@ -860,7 +887,7 @@ def test_to_device_module_vs_non_module(): # Process data on GPU obs_gpu = torch.randn(2, 5, device="cuda") - transition = (obs_gpu, None, 2.0, False, False, {}, {}) + transition = create_transition(observation=obs_gpu, reward=2.0) _ = pipeline(transition) # Verify both steps processed the data @@ -889,7 +916,8 @@ class MockStepWithNonSerializableParam: self.env = env # Non-serializable parameter (like gym.Env) def __call__(self, transition: EnvTransition) -> EnvTransition: - obs, action, reward, done, truncated, info, comp_data = transition + reward = transition.get(TransitionKey.REWARD) + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) # Use the env parameter if provided if self.env is not None: @@ -897,10 +925,14 @@ class MockStepWithNonSerializableParam: comp_data[f"{self.name}_env_info"] = str(self.env) # Apply multiplier to reward + new_transition = transition.copy() if reward is not None: - reward = reward * self.multiplier + new_transition[TransitionKey.REWARD] = reward * self.multiplier - return (obs, action, reward, done, truncated, info, comp_data) + if comp_data: + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + + return new_transition def get_config(self) -> dict[str, Any]: # Note: env is intentionally NOT included here as it's not serializable @@ -928,13 +960,15 @@ class RegisteredMockStep: device: str = "cpu" def __call__(self, transition: EnvTransition) -> EnvTransition: - obs, action, reward, done, truncated, info, comp_data = transition + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) comp_data = {} if comp_data is None else dict(comp_data) comp_data["registered_step_value"] = self.value comp_data["registered_step_device"] = self.device - return (obs, action, reward, done, truncated, info, comp_data) + new_transition = transition.copy() + 
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition def get_config(self) -> dict[str, Any]: return { @@ -993,18 +1027,18 @@ def test_from_pretrained_with_overrides(): assert loaded_pipeline.name == "TestOverrides" # Test the loaded steps - transition = (None, None, 1.0, False, False, {}, {}) + transition = create_transition(reward=1.0) result = loaded_pipeline(transition) # Check that overrides were applied - comp_data = result[6] + comp_data = result[TransitionKey.COMPLEMENTARY_DATA] assert "env_step_env_info" in comp_data assert comp_data["env_step_env_info"] == "MockEnvironment(test_env)" assert comp_data["registered_step_value"] == 200 assert comp_data["registered_step_device"] == "cuda" # Check that multiplier override was applied - assert result[2] == 3.0 # 1.0 * 3.0 (overridden multiplier) + assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0 (overridden multiplier) def test_from_pretrained_with_partial_overrides(): @@ -1024,13 +1058,13 @@ def test_from_pretrained_with_partial_overrides(): # Both steps will get the override loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) - transition = (None, None, 1.0, False, False, {}, {}) + transition = create_transition(reward=1.0) result = loaded_pipeline(transition) # The reward should be affected by both steps, both getting the override # First step: 1.0 * 5.0 = 5.0 (overridden) # Second step: 5.0 * 5.0 = 25.0 (also overridden) - assert result[2] == 25.0 + assert result[TransitionKey.REWARD] == 25.0 def test_from_pretrained_invalid_override_key(): @@ -1082,10 +1116,10 @@ def test_from_pretrained_registered_step_override(): loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) # Test that overrides were applied - transition = (None, None, 0.0, False, False, {}, {}) + transition = create_transition() result = loaded_pipeline(transition) - comp_data = result[6] + comp_data = result[TransitionKey.COMPLEMENTARY_DATA] assert 
comp_data["registered_step_value"] == 999 assert comp_data["registered_step_device"] == "cuda" @@ -1110,13 +1144,13 @@ def test_from_pretrained_mixed_registered_and_unregistered(): loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) # Test both steps - transition = (None, None, 2.0, False, False, {}, {}) + transition = create_transition(reward=2.0) result = loaded_pipeline(transition) - comp_data = result[6] + comp_data = result[TransitionKey.COMPLEMENTARY_DATA] assert comp_data["unregistered_env_info"] == "MockEnvironment(mixed_test)" assert comp_data["registered_step_value"] == 777 - assert result[2] == 8.0 # 2.0 * 4.0 + assert result[TransitionKey.REWARD] == 8.0 # 2.0 * 4.0 def test_from_pretrained_no_overrides(): @@ -1133,10 +1167,10 @@ def test_from_pretrained_no_overrides(): assert len(loaded_pipeline) == 1 # Test that the step works (env will be None) - transition = (None, None, 1.0, False, False, {}, {}) + transition = create_transition(reward=1.0) result = loaded_pipeline(transition) - assert result[2] == 3.0 # 1.0 * 3.0 + assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0 def test_from_pretrained_empty_overrides(): @@ -1153,10 +1187,10 @@ def test_from_pretrained_empty_overrides(): assert len(loaded_pipeline) == 1 # Test that the step works normally - transition = (None, None, 1.0, False, False, {}, {}) + transition = create_transition(reward=1.0) result = loaded_pipeline(transition) - assert result[2] == 2.0 + assert result[TransitionKey.REWARD] == 2.0 def test_from_pretrained_override_instantiation_error(): @@ -1185,7 +1219,7 @@ def test_from_pretrained_with_state_and_overrides(): # Process some data to create state for i in range(10): - transition = (None, None, float(i), False, False, {}, {}) + transition = create_transition(reward=float(i)) pipeline(transition) with tempfile.TemporaryDirectory() as tmp_dir: diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py index 
b9564bbb5..2636692c1 100644 --- a/tests/processor/test_rename_processor.py +++ b/tests/processor/test_rename_processor.py @@ -13,14 +13,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import tempfile from pathlib import Path import numpy as np import torch -from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionIndex +from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey + + +def create_transition( + observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None +): + """Helper to create an EnvTransition dictionary.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: complementary_data, + } def test_basic_renaming(): @@ -36,10 +50,10 @@ def test_basic_renaming(): "old_key2": np.array([3.0, 4.0]), "unchanged_key": "keep_me", } - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[TransitionIndex.OBSERVATION] + processed_obs = result[TransitionKey.OBSERVATION] # Check renamed keys assert "new_key1" in processed_obs @@ -63,10 +77,10 @@ def test_empty_rename_map(): "key1": torch.tensor([1.0]), "key2": "value2", } - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[TransitionIndex.OBSERVATION] + processed_obs = result[TransitionKey.OBSERVATION] # All keys should be unchanged assert processed_obs.keys() == observation.keys() @@ -78,7 +92,7 @@ def 
test_none_observation(): """Test processor with None observation.""" processor = RenameProcessor(rename_map={"old": "new"}) - transition = (None, None, None, None, None, None, None) + transition = create_transition() result = processor(transition) # Should return transition unchanged @@ -98,10 +112,10 @@ def test_overlapping_rename(): "b": 2, "x": 3, } - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[TransitionIndex.OBSERVATION] + processed_obs = result[TransitionKey.OBSERVATION] # Check that renaming happens correctly assert "a" not in processed_obs @@ -124,10 +138,10 @@ def test_partial_rename(): "reward": 1.0, "info": {"episode": 1}, } - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[TransitionIndex.OBSERVATION] + processed_obs = result[TransitionKey.OBSERVATION] # Check renamed keys assert "observation.proprio_state" in processed_obs @@ -178,10 +192,12 @@ def test_integration_with_robot_processor(): "pixels": np.zeros((32, 32, 3), dtype=np.uint8), "other_data": "preserve_me", } - transition = (observation, None, 0.5, False, False, {}, {}) + transition = create_transition( + observation=observation, reward=0.5, done=False, truncated=False, info={}, complementary_data={} + ) result = pipeline(transition) - processed_obs = result[TransitionIndex.OBSERVATION] + processed_obs = result[TransitionKey.OBSERVATION] # Check renaming worked through pipeline assert "observation.state" in processed_obs @@ -191,8 +207,8 @@ def test_integration_with_robot_processor(): assert processed_obs["other_data"] == "preserve_me" # Check other transition elements unchanged - assert result[TransitionIndex.REWARD] == 0.5 - assert result[TransitionIndex.DONE] is False + assert result[TransitionKey.REWARD] == 0.5 + assert 
result[TransitionKey.DONE] is False def test_save_and_load_pretrained(): @@ -229,10 +245,10 @@ def test_save_and_load_pretrained(): # Test functionality after loading observation = {"old_state": [1, 2, 3], "old_image": "image_data"} - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = loaded_pipeline(transition) - processed_obs = result[TransitionIndex.OBSERVATION] + processed_obs = result[TransitionKey.OBSERVATION] assert "observation.state" in processed_obs assert "observation.image" in processed_obs @@ -306,17 +322,17 @@ def test_chained_rename_processors(): "img": "image_data", "extra": "keep_me", } - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) # Step through to see intermediate results results = list(pipeline.step_through(transition)) # After first processor - assert "agent_position" in results[1][TransitionIndex.OBSERVATION] - assert "camera_image" in results[1][TransitionIndex.OBSERVATION] + assert "agent_position" in results[1][TransitionKey.OBSERVATION] + assert "camera_image" in results[1][TransitionKey.OBSERVATION] # After second processor - final_obs = results[2][TransitionIndex.OBSERVATION] + final_obs = results[2][TransitionKey.OBSERVATION] assert "observation.state" in final_obs assert "observation.image" in final_obs assert final_obs["extra"] == "keep_me" @@ -343,10 +359,10 @@ def test_nested_observation_rename(): "observation.proprio": torch.randn(7), "observation.gripper": torch.tensor([0.0]), # Not renamed } - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[TransitionIndex.OBSERVATION] + processed_obs = result[TransitionKey.OBSERVATION] # Check renames assert "observation.camera.left_view" in processed_obs @@ -378,10 +394,10 @@ def test_value_types_preserved(): 
"old_dict": {"nested": "value"}, "old_list": [1, 2, 3], } - transition = (observation, None, None, None, None, None, None) + transition = create_transition(observation=observation) result = processor(transition) - processed_obs = result[TransitionIndex.OBSERVATION] + processed_obs = result[TransitionKey.OBSERVATION] # Check that values and types are preserved assert torch.equal(processed_obs["new_tensor"], tensor_value)