diff --git a/docs/source/processor_tutorial.mdx b/docs/source/processor_tutorial.mdx index 7646a795e..a6e143725 100644 --- a/docs/source/processor_tutorial.mdx +++ b/docs/source/processor_tutorial.mdx @@ -60,7 +60,7 @@ If you've worked with robot learning before, you've likely written preprocessing # Traditional procedural approach - hard to maintain and share def preprocess_observation(obs, device='cuda'): processed_obs = {} - + # Process images if "pixels" in obs: for cam_name, img in obs["pixels"].items(): @@ -76,13 +76,13 @@ def preprocess_observation(obs, device='cuda'): if img.shape[-2:] != (224, 224): img = F.pad(img, (0, 224 - img.shape[-1], 0, 224 - img.shape[-2])) processed_obs[f"observation.images.{cam_name}"] = img - + # Process state if "agent_pos" in obs: state = torch.from_numpy(obs["agent_pos"]).float() state = state.unsqueeze(0).to(device) processed_obs["observation.state"] = state - + return processed_obs ``` @@ -180,7 +180,7 @@ print("Original keys:", observation.keys()) print("Processed keys:", processed_obs.keys()) print("Image shape:", processed_obs["observation.images.camera_front"].shape) # [1, 3, 480, 640] print("Image dtype:", processed_obs["observation.images.camera_front"].dtype) # torch.float32 -print("Image range:", processed_obs["observation.images.camera_front"].min().item(), +print("Image range:", processed_obs["observation.images.camera_front"].min().item(), "to", processed_obs["observation.images.camera_front"].max().item()) # 0.0 to 1.0 ``` @@ -197,18 +197,18 @@ import torch.nn.functional as F @dataclass class ImagePadder: """Pad images to a standard size for batch processing.""" - + target_height: int = 224 target_width: int = 224 pad_value: float = 0.0 - + def __call__(self, transition: EnvTransition) -> EnvTransition: """Main processing method - required for all steps.""" obs = transition[TransitionIndex.OBSERVATION] - + if obs is None: return transition - + # Process all image observations for key in list(obs.keys()): if key.startswith("observation.images."): @@ -217,22 +217,22 @@ class ImagePadder: _, _, h, w = img.shape pad_h = max(0, self.target_height - h) pad_w = max(0, self.target_width - w) - + if pad_h > 0 or pad_w > 0: # Pad symmetrically pad_left = pad_w // 2 pad_right = pad_w - pad_left pad_top = pad_h // 2 pad_bottom = pad_h - pad_top - - img = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom), + + img = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=self.pad_value) - + obs[key] = img - + # Return modified transition return (obs, *transition[1:]) - + def get_config(self) -> Dict[str, Any]: """Return JSON-serializable configuration - required for save/load.""" return { @@ -240,17 +240,17 @@ class ImagePadder: "target_width": self.target_width, "pad_value": self.pad_value } - + def state_dict(self) -> Dict[str, torch.Tensor]: """Return tensor state - only include torch.Tensor objects!""" # This step has no learnable parameters return {} - + def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: """Load tensor state - required if state_dict returns non-empty dict.""" # Nothing to load for this step pass - + def reset(self) -> None: """Reset internal state at episode boundaries - required for stateful steps.""" # This step is stateless, so nothing to reset @@ -265,15 +265,15 @@ These two methods serve different purposes and it's crucial to use them correctl @dataclass class AdaptiveNormalizer: """Example showing proper use of get_config and state_dict.""" - + learning_rate: float = 0.01 epsilon: float = 1e-8 - + def __post_init__(self): self.running_mean = None self.running_var = None self.num_samples = 0 # Python int, not tensor - + def get_config(self) -> Dict[str, Any]: """ONLY Python objects that can be JSON serialized!""" return { @@ -282,7 +282,7 @@ class AdaptiveNormalizer: "num_samples": self.num_samples, # int ✓ # "running_mean": self.running_mean, # torch.Tensor ✗ WRONG! } - + def state_dict(self) -> Dict[str, torch.Tensor]: """ONLY torch.Tensor objects!""" if self.running_mean is None: @@ -294,7 +294,7 @@ class AdaptiveNormalizer: # Instead, convert to tensor if needed: "num_samples_tensor": torch.tensor(self.num_samples) } - + def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: """Load tensors and convert back to Python types if needed.""" self.running_mean = state.get("running_mean") @@ -311,14 +311,14 @@ The `complementary_data` field is perfect for passing information between steps @dataclass class ImageStatisticsCalculator: """Calculate image statistics and pass to next steps.""" - + def __call__(self, transition: EnvTransition) -> EnvTransition: obs = transition[TransitionIndex.OBSERVATION] comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {} - + if obs is None: return transition - + # Calculate statistics for all images image_stats = {} for key in obs: @@ -331,10 +331,10 @@ class ImageStatisticsCalculator: "max": img.max().item(), } image_stats[key] = stats - + # Store in complementary_data for next steps comp_data["image_statistics"] = image_stats - + # Return transition with updated complementary_data return ( obs, @@ -346,30 +346,30 @@ class ImageStatisticsCalculator: comp_data # Updated complementary_data ) -@dataclass +@dataclass class AdaptiveBrightnessAdjuster: """Adjust brightness based on statistics from previous step.""" - + target_brightness: float = 0.5 - + def __call__(self, transition: EnvTransition) -> EnvTransition: obs = transition[TransitionIndex.OBSERVATION] comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {} - + if obs is None or "image_statistics" not in comp_data: return transition - + # Use statistics from previous step image_stats = comp_data["image_statistics"] - + for key in obs: if key.startswith("observation.images.") and key in image_stats: current_mean = image_stats[key]["mean"] brightness_adjust = self.target_brightness - current_mean - + # Adjust brightness obs[key] = torch.clamp(obs[key] + brightness_adjust, 0, 1) - + return (obs, *transition[1:]) # Use them together @@ -394,28 +394,28 @@ import numpy as np @dataclass class DeviceMover: """Move all tensors to specified device.""" - + device: str = "cuda" - + def __call__(self, transition: EnvTransition) -> EnvTransition: obs = transition[TransitionIndex.OBSERVATION] - + if obs is None: return transition - + # Move all tensor observations to device for key, value in obs.items(): if isinstance(value, torch.Tensor): obs[key] = value.to(self.device) - + # Also handle action if present action = transition[TransitionIndex.ACTION] if action is not None and isinstance(action, torch.Tensor): action = action.to(self.device) return (obs, action, *transition[2:]) - + return (obs, *transition[1:]) - + def get_config(self) -> Dict[str, Any]: return {"device": str(self.device)} @@ -423,52 +423,52 @@ class DeviceMover: @dataclass class RunningNormalizer: """Normalize using running statistics with proper device handling.""" - + feature_dim: int momentum: float = 0.1 - + def __post_init__(self): # Initialize as None - will be created on first call with correct device self.running_mean = None self.running_var = None self.initialized = False - + def __call__(self, transition: EnvTransition) -> EnvTransition: obs = transition[TransitionIndex.OBSERVATION] - + if obs is None or "observation.state" not in obs: return transition - + state = obs["observation.state"] - + # Initialize on first call with correct device if not self.initialized: device = state.device self.running_mean = torch.zeros(self.feature_dim, device=device) self.running_var = torch.ones(self.feature_dim, device=device) self.initialized = True - + # Update statistics with torch.no_grad(): batch_mean = state.mean(dim=0) batch_var = state.var(dim=0, unbiased=False) - + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var - + # Normalize state_normalized = (state - self.running_mean) / (self.running_var + 1e-8).sqrt() obs["observation.state"] = state_normalized - + return (obs, *transition[1:]) - + def get_config(self) -> Dict[str, Any]: return { "feature_dim": self.feature_dim, "momentum": self.momentum, "initialized": self.initialized } - + def state_dict(self) -> Dict[str, torch.Tensor]: if not self.initialized: return {} @@ -476,13 +476,13 @@ class RunningNormalizer: "running_mean": self.running_mean, "running_var": self.running_var } - + def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: if state: self.running_mean = state["running_mean"] self.running_var = state["running_var"] self.initialized = True - + def reset(self) -> None: # Don't reset statistics - they persist across episodes pass @@ -524,29 +524,29 @@ When your environment and policy speak different "languages": @dataclass class KeyRemapper: """Rename observation keys to match policy expectations.""" - + key_mapping: Dict[str, str] = field(default_factory=lambda: { "rgb_camera_front": "observation.images.wrist", "joint_positions": "observation.state", "gripper_state": "observation.gripper" }) - + def __call__(self, transition: EnvTransition) -> EnvTransition: obs = transition[TransitionIndex.OBSERVATION] if obs is None: return transition - + # Create new observation with renamed keys renamed_obs = {} for old_key, new_key in self.key_mapping.items(): if old_key in obs: renamed_obs[new_key] = obs[old_key] - + # Keep any unmapped keys as-is for key, value in obs.items(): if key not in self.key_mapping: renamed_obs[key] = value - + return (renamed_obs, *transition[1:]) ``` @@ -559,15 +559,15 @@ When you need to crop and resize images to focus on the manipulation workspace: @dataclass class WorkspaceCropper: """Crop and resize images to focus on robot workspace.""" - + crop_bbox: Tuple[int, int, int, int] = (400, 200, 1200, 800) # (x1, y1, x2, y2) output_size: Tuple[int, int] = (224, 224) - + def __call__(self, transition: EnvTransition) -> EnvTransition: obs = transition[TransitionIndex.OBSERVATION] if obs is None: return transition - + for key in list(obs.keys()): if key.startswith("observation.images."): img = obs[key] @@ -576,13 +576,13 @@ class WorkspaceCropper: img_cropped = img[:, :, y1:y2, x1:x2] # Resize to expected dimensions img_resized = F.interpolate( - img_cropped, - size=self.output_size, - mode='bilinear', + img_cropped, + size=self.output_size, + mode='bilinear', align_corners=False ) obs[key] = img_resized - + return (obs, *transition[1:]) ``` @@ -596,7 +596,7 @@ Now you can compose these steps into robot-specific pipelines: robot_a_processor = RobotProcessor([ # Observation preprocessing KeyRemapper({ - "wrist_rgb": "observation.images.wrist", + "wrist_rgb": "observation.images.wrist", "arm_joints": "observation.state" }), ImageProcessor(), @@ -604,7 +604,7 @@ robot_a_processor = RobotProcessor([ StateProcessor(), VelocityCalculator(state_key="observation.state"), DeviceMover("cuda"), - + # Action postprocessing ActionSmoother(alpha=0.2), ], name="RobotA_ACT_Processor") @@ -618,10 +618,10 @@ robot_b_processor = RobotProcessor([ }), ImageProcessor(), WorkspaceCropper(crop_bbox=(100, 50, 1100, 950), output_size=(224, 224)), - StateProcessor(), + StateProcessor(), VelocityCalculator(state_key="observation.state"), DeviceMover("cuda"), - + ActionSmoother(alpha=0.3), # Different smoothing ], name="RobotB_ACT_Processor") @@ -645,11 +645,11 @@ The beauty of this approach is that: ```python def __call__(self, transition: EnvTransition) -> EnvTransition: obs = transition[TransitionIndex.OBSERVATION] - + # Always check if observation exists if obs is None: return transition - + # Also check for specific keys if "observation.state" not in obs: return transition @@ -668,11 +668,11 @@ return (modified_obs, None, 0.0, False, False, {}, {}) ```python def __call__(self, transition: EnvTransition) -> EnvTransition: obs = transition[TransitionIndex.OBSERVATION] - + if self.store_previous: # Good - clone to avoid reference issues self.previous_state = obs["observation.state"].clone() - + # Bad - stores reference that might be modified # self.previous_state = obs["observation.state"] ``` @@ -682,7 +682,7 @@ def __call__(self, transition: EnvTransition) -> EnvTransition: def state_dict(self) -> Dict[str, torch.Tensor]: if self.buffer is None: return {} - + # Good - clone to avoid memory sharing issues return {"buffer": self.buffer.clone()} ``` @@ -713,16 +713,16 @@ class ActionClipper: """Clip actions to safe ranges.""" min_value: float = -1.0 max_value: float = 1.0 - + def __call__(self, transition: EnvTransition) -> EnvTransition: action = transition[TransitionIndex.ACTION] - + if action is not None: action = torch.clamp(action, self.min_value, self.max_value) return (transition[TransitionIndex.OBSERVATION], action, *transition[2:]) - + return transition - + def get_config(self) -> Dict[str, Any]: return {"min_value": self.min_value, "max_value": self.max_value} @@ -736,7 +736,7 @@ policy = ACTPolicy.from_pretrained("lerobot/act_aloha_sim_transfer_cube_human") # Move everything to GPU preprocessor = preprocessor.to("cuda") -postprocessor = postprocessor.to("cuda") +postprocessor = postprocessor.to("cuda") policy = policy.to("cuda") # Control loop @@ -745,42 +745,42 @@ obs, info = env.reset() for episode in range(10): print(f"Episode {episode + 1}") - + for step in range(1000): # Create transition with raw observation transition = (obs, None, 0.0, False, False, info, {"step": step}) - + # Preprocess processed_transition = preprocessor(transition) processed_obs = processed_transition[TransitionIndex.OBSERVATION] - + # Get action from policy with torch.no_grad(): action = policy.select_action(processed_obs) - + # Postprocess action action_transition = ( - processed_obs, - action, - 0.0, - False, - False, - info, + processed_obs, + action, + 0.0, + False, + False, + info, {"raw_action": action.clone()} # Store raw action in complementary_data ) processed_action_transition = postprocessor(action_transition) final_action = processed_action_transition[TransitionIndex.ACTION] - + # Execute action obs, reward, terminated, truncated, info = env.step(final_action.cpu().numpy()) - + if terminated or truncated: # Reset at episode boundary preprocessor.reset() postprocessor.reset() obs, info = env.reset() break - + # Save preprocessor with learned statistics preprocessor.save_pretrained(f"./checkpoints/preprocessor_ep{episode}") @@ -844,4 +844,4 @@ RobotProcessor provides a powerful, modular approach to data preprocessing in ro By following these patterns, your preprocessing code becomes more maintainable, shareable, and robust. -For the full API reference, see the [RobotProcessor API documentation](/api/processor). \ No newline at end of file +For the full API reference, see the [RobotProcessor API documentation](/api/processor).