[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot]
2025-07-03 13:22:41 +00:00
committed by Adil Zouitine
parent cd13f1ecfd
commit ac742c9f0d
+96 -96
@@ -60,7 +60,7 @@ If you've worked with robot learning before, you've likely written preprocessing
# Traditional procedural approach - hard to maintain and share
def preprocess_observation(obs, device='cuda'):
processed_obs = {}
# Process images
if "pixels" in obs:
for cam_name, img in obs["pixels"].items():
@@ -76,13 +76,13 @@ def preprocess_observation(obs, device='cuda'):
if img.shape[-2:] != (224, 224):
img = F.pad(img, (0, 224 - img.shape[-1], 0, 224 - img.shape[-2]))
processed_obs[f"observation.images.{cam_name}"] = img
# Process state
if "agent_pos" in obs:
state = torch.from_numpy(obs["agent_pos"]).float()
state = state.unsqueeze(0).to(device)
processed_obs["observation.state"] = state
return processed_obs
```
@@ -180,7 +180,7 @@ print("Original keys:", observation.keys())
print("Processed keys:", processed_obs.keys())
print("Image shape:", processed_obs["observation.images.camera_front"].shape) # [1, 3, 480, 640]
print("Image dtype:", processed_obs["observation.images.camera_front"].dtype) # torch.float32
print("Image range:", processed_obs["observation.images.camera_front"].min().item(),
print("Image range:", processed_obs["observation.images.camera_front"].min().item(),
"to", processed_obs["observation.images.camera_front"].max().item()) # 0.0 to 1.0
```
@@ -197,18 +197,18 @@ import torch.nn.functional as F
@dataclass
class ImagePadder:
"""Pad images to a standard size for batch processing."""
target_height: int = 224
target_width: int = 224
pad_value: float = 0.0
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Main processing method - required for all steps."""
obs = transition[TransitionIndex.OBSERVATION]
if obs is None:
return transition
# Process all image observations
for key in list(obs.keys()):
if key.startswith("observation.images."):
@@ -217,22 +217,22 @@ class ImagePadder:
_, _, h, w = img.shape
pad_h = max(0, self.target_height - h)
pad_w = max(0, self.target_width - w)
if pad_h > 0 or pad_w > 0:
# Pad symmetrically
pad_left = pad_w // 2
pad_right = pad_w - pad_left
pad_top = pad_h // 2
pad_bottom = pad_h - pad_top
img = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom),
mode='constant', value=self.pad_value)
obs[key] = img
# Return modified transition
return (obs, *transition[1:])
def get_config(self) -> Dict[str, Any]:
"""Return JSON-serializable configuration - required for save/load."""
return {
@@ -240,17 +240,17 @@ class ImagePadder:
"target_width": self.target_width,
"pad_value": self.pad_value
}
def state_dict(self) -> Dict[str, torch.Tensor]:
"""Return tensor state - only include torch.Tensor objects!"""
# This step has no learnable parameters
return {}
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
"""Load tensor state - required if state_dict returns non-empty dict."""
# Nothing to load for this step
pass
def reset(self) -> None:
"""Reset internal state at episode boundaries - required for stateful steps."""
# This step is stateless, so nothing to reset
@@ -265,15 +265,15 @@ These two methods serve different purposes and it's crucial to use them correctl
@dataclass
class AdaptiveNormalizer:
"""Example showing proper use of get_config and state_dict."""
learning_rate: float = 0.01
epsilon: float = 1e-8
def __post_init__(self):
self.running_mean = None
self.running_var = None
self.num_samples = 0 # Python int, not tensor
def get_config(self) -> Dict[str, Any]:
"""ONLY Python objects that can be JSON serialized!"""
return {
@@ -282,7 +282,7 @@ class AdaptiveNormalizer:
"num_samples": self.num_samples, # int ✓
# "running_mean": self.running_mean, # torch.Tensor ✗ WRONG!
}
def state_dict(self) -> Dict[str, torch.Tensor]:
"""ONLY torch.Tensor objects!"""
if self.running_mean is None:
@@ -294,7 +294,7 @@ class AdaptiveNormalizer:
# Instead, convert to tensor if needed:
"num_samples_tensor": torch.tensor(self.num_samples)
}
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
"""Load tensors and convert back to Python types if needed."""
self.running_mean = state.get("running_mean")
@@ -311,14 +311,14 @@ The `complementary_data` field is perfect for passing information between steps
@dataclass
class ImageStatisticsCalculator:
"""Calculate image statistics and pass to next steps."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {}
if obs is None:
return transition
# Calculate statistics for all images
image_stats = {}
for key in obs:
@@ -331,10 +331,10 @@ class ImageStatisticsCalculator:
"max": img.max().item(),
}
image_stats[key] = stats
# Store in complementary_data for next steps
comp_data["image_statistics"] = image_stats
# Return transition with updated complementary_data
return (
obs,
@@ -346,30 +346,30 @@ class ImageStatisticsCalculator:
comp_data # Updated complementary_data
)
@dataclass
class AdaptiveBrightnessAdjuster:
"""Adjust brightness based on statistics from previous step."""
target_brightness: float = 0.5
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {}
if obs is None or "image_statistics" not in comp_data:
return transition
# Use statistics from previous step
image_stats = comp_data["image_statistics"]
for key in obs:
if key.startswith("observation.images.") and key in image_stats:
current_mean = image_stats[key]["mean"]
brightness_adjust = self.target_brightness - current_mean
# Adjust brightness
obs[key] = torch.clamp(obs[key] + brightness_adjust, 0, 1)
return (obs, *transition[1:])
# Use them together
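# A hedged sketch (the composition code itself is not shown in this diff):
# the first step writes "image_statistics" into complementary_data, and the
# second step reads it back out on the same pass through the pipeline.
adaptive_pipeline = RobotProcessor([
    ImageStatisticsCalculator(),
    AdaptiveBrightnessAdjuster(target_brightness=0.5),
])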
@@ -394,28 +394,28 @@ import numpy as np
@dataclass
class DeviceMover:
"""Move all tensors to specified device."""
device: str = "cuda"
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
if obs is None:
return transition
# Move all tensor observations to device
for key, value in obs.items():
if isinstance(value, torch.Tensor):
obs[key] = value.to(self.device)
# Also handle action if present
action = transition[TransitionIndex.ACTION]
if action is not None and isinstance(action, torch.Tensor):
action = action.to(self.device)
return (obs, action, *transition[2:])
return (obs, *transition[1:])
def get_config(self) -> Dict[str, Any]:
return {"device": str(self.device)}
@@ -423,52 +423,52 @@ class DeviceMover:
@dataclass
class RunningNormalizer:
"""Normalize using running statistics with proper device handling."""
feature_dim: int
momentum: float = 0.1
def __post_init__(self):
# Initialize as None - will be created on first call with correct device
self.running_mean = None
self.running_var = None
self.initialized = False
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
if obs is None or "observation.state" not in obs:
return transition
state = obs["observation.state"]
# Initialize on first call with correct device
if not self.initialized:
device = state.device
self.running_mean = torch.zeros(self.feature_dim, device=device)
self.running_var = torch.ones(self.feature_dim, device=device)
self.initialized = True
# Update statistics
with torch.no_grad():
batch_mean = state.mean(dim=0)
batch_var = state.var(dim=0, unbiased=False)
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
# Normalize
state_normalized = (state - self.running_mean) / (self.running_var + 1e-8).sqrt()
obs["observation.state"] = state_normalized
return (obs, *transition[1:])
def get_config(self) -> Dict[str, Any]:
return {
"feature_dim": self.feature_dim,
"momentum": self.momentum,
"initialized": self.initialized
}
def state_dict(self) -> Dict[str, torch.Tensor]:
if not self.initialized:
return {}
@@ -476,13 +476,13 @@ class RunningNormalizer:
"running_mean": self.running_mean,
"running_var": self.running_var
}
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
if state:
self.running_mean = state["running_mean"]
self.running_var = state["running_var"]
self.initialized = True
def reset(self) -> None:
# Don't reset statistics - they persist across episodes
pass
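
# A quick usage sketch (hypothetical feature_dim and batch size): the buffers
# are created lazily on the device of the first state seen, so the same step
# works on CPU or GPU without changes.
normalizer = RunningNormalizer(feature_dim=7)
obs = {"observation.state": torch.randn(4, 7)}
_ = normalizer((obs, None, 0.0, False, False, {}, {}))
print(normalizer.running_mean.device)  # cpu, matching the input state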
@@ -524,29 +524,29 @@ When your environment and policy speak different "languages":
@dataclass
class KeyRemapper:
"""Rename observation keys to match policy expectations."""
key_mapping: Dict[str, str] = field(default_factory=lambda: {
"rgb_camera_front": "observation.images.wrist",
"joint_positions": "observation.state",
"gripper_state": "observation.gripper"
})
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
if obs is None:
return transition
# Create new observation with renamed keys
renamed_obs = {}
for old_key, new_key in self.key_mapping.items():
if old_key in obs:
renamed_obs[new_key] = obs[old_key]
# Keep any unmapped keys as-is
for key, value in obs.items():
if key not in self.key_mapping:
renamed_obs[key] = value
return (renamed_obs, *transition[1:])
```
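As a quick sanity check, here is a minimal usage sketch of `KeyRemapper` (the raw environment keys are hypothetical):

```python
import torch

remapper = KeyRemapper(key_mapping={
    "rgb_camera_front": "observation.images.front",
    "joint_positions": "observation.state",
})
raw_obs = {
    "rgb_camera_front": torch.rand(1, 3, 480, 640),
    "joint_positions": torch.rand(1, 7),
}
transition = (raw_obs, None, 0.0, False, False, {}, {})
remapped_obs = remapper(transition)[TransitionIndex.OBSERVATION]
print(sorted(remapped_obs.keys()))
# ['observation.images.front', 'observation.state']
```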
@@ -559,15 +559,15 @@ When you need to crop and resize images to focus on the manipulation workspace:
@dataclass
class WorkspaceCropper:
"""Crop and resize images to focus on robot workspace."""
crop_bbox: Tuple[int, int, int, int] = (400, 200, 1200, 800) # (x1, y1, x2, y2)
output_size: Tuple[int, int] = (224, 224)
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
if obs is None:
return transition
for key in list(obs.keys()):
if key.startswith("observation.images."):
img = obs[key]
@@ -576,13 +576,13 @@ class WorkspaceCropper:
img_cropped = img[:, :, y1:y2, x1:x2]
# Resize to expected dimensions
img_resized = F.interpolate(
img_cropped,
size=self.output_size,
mode='bilinear',
align_corners=False
)
obs[key] = img_resized
return (obs, *transition[1:])
```
@@ -596,7 +596,7 @@ Now you can compose these steps into robot-specific pipelines:
robot_a_processor = RobotProcessor([
# Observation preprocessing
KeyRemapper({
"wrist_rgb": "observation.images.wrist",
"wrist_rgb": "observation.images.wrist",
"arm_joints": "observation.state"
}),
ImageProcessor(),
@@ -604,7 +604,7 @@ robot_a_processor = RobotProcessor([
StateProcessor(),
VelocityCalculator(state_key="observation.state"),
DeviceMover("cuda"),
# Action postprocessing
ActionSmoother(alpha=0.2),
], name="RobotA_ACT_Processor")
@@ -618,10 +618,10 @@ robot_b_processor = RobotProcessor([
}),
ImageProcessor(),
WorkspaceCropper(crop_bbox=(100, 50, 1100, 950), output_size=(224, 224)),
StateProcessor(),
VelocityCalculator(state_key="observation.state"),
DeviceMover("cuda"),
ActionSmoother(alpha=0.3), # Different smoothing
], name="RobotB_ACT_Processor")
@@ -645,11 +645,11 @@ The beauty of this approach is that:
```python
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
# Always check if observation exists
if obs is None:
return transition
# Also check for specific keys
if "observation.state" not in obs:
return transition
@@ -668,11 +668,11 @@ return (modified_obs, None, 0.0, False, False, {}, {})
```python
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION]
if self.store_previous:
# Good - clone to avoid reference issues
self.previous_state = obs["observation.state"].clone()
# Bad - stores reference that might be modified
# self.previous_state = obs["observation.state"]
```
@@ -682,7 +682,7 @@ def __call__(self, transition: EnvTransition) -> EnvTransition:
def state_dict(self) -> Dict[str, torch.Tensor]:
if self.buffer is None:
return {}
# Good - clone to avoid memory sharing issues
return {"buffer": self.buffer.clone()}
```
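To see why the `clone()` matters, here is a small illustration with a plain tensor standing in for the step's buffer:

```python
import torch

buffer = torch.zeros(3)
checkpoint = {"buffer": buffer}  # stores a reference, not a snapshot
buffer += 1.0                    # later in-place update...
print(checkpoint["buffer"])      # tensor([1., 1., 1.]) -- the "saved" state mutated

safe = {"buffer": buffer.clone()}  # clone() takes an isolated snapshot
buffer += 1.0
print(safe["buffer"])              # tensor([1., 1., 1.]) -- unaffected
```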
@@ -713,16 +713,16 @@ class ActionClipper:
"""Clip actions to safe ranges."""
min_value: float = -1.0
max_value: float = 1.0
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition[TransitionIndex.ACTION]
if action is not None:
action = torch.clamp(action, self.min_value, self.max_value)
return (transition[TransitionIndex.OBSERVATION], action, *transition[2:])
return transition
def get_config(self) -> Dict[str, Any]:
return {"min_value": self.min_value, "max_value": self.max_value}
@@ -736,7 +736,7 @@ policy = ACTPolicy.from_pretrained("lerobot/act_aloha_sim_transfer_cube_human")
# Move everything to GPU
preprocessor = preprocessor.to("cuda")
postprocessor = postprocessor.to("cuda")
policy = policy.to("cuda")
# Control loop
@@ -745,42 +745,42 @@ obs, info = env.reset()
for episode in range(10):
print(f"Episode {episode + 1}")
for step in range(1000):
# Create transition with raw observation
transition = (obs, None, 0.0, False, False, info, {"step": step})
# Preprocess
processed_transition = preprocessor(transition)
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
# Get action from policy
with torch.no_grad():
action = policy.select_action(processed_obs)
# Postprocess action
action_transition = (
processed_obs,
action,
0.0,
False,
False,
info,
{"raw_action": action.clone()} # Store raw action in complementary_data
)
processed_action_transition = postprocessor(action_transition)
final_action = processed_action_transition[TransitionIndex.ACTION]
# Execute action
obs, reward, terminated, truncated, info = env.step(final_action.cpu().numpy())
if terminated or truncated:
# Reset at episode boundary
preprocessor.reset()
postprocessor.reset()
obs, info = env.reset()
break
# Save preprocessor with learned statistics
preprocessor.save_pretrained(f"./checkpoints/preprocessor_ep{episode}")
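
# Hedged sketch: reloading the saved processor later. This assumes a
# RobotProcessor.from_pretrained counterpart to save_pretrained, mirroring
# the policy API above -- check the API reference before relying on it.
# restored = RobotProcessor.from_pretrained(f"./checkpoints/preprocessor_ep{episode}")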
@@ -844,4 +844,4 @@ RobotProcessor provides a powerful, modular approach to data preprocessing in ro
By following these patterns, you make your preprocessing code more maintainable, shareable, and robust.
For the full API reference, see the [RobotProcessor API documentation](/api/processor).