mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Adil Zouitine
parent
cd13f1ecfd
commit
ac742c9f0d
@@ -60,7 +60,7 @@ If you've worked with robot learning before, you've likely written preprocessing
|
||||
# Traditional procedural approach - hard to maintain and share
|
||||
def preprocess_observation(obs, device='cuda'):
|
||||
processed_obs = {}
|
||||
|
||||
|
||||
# Process images
|
||||
if "pixels" in obs:
|
||||
for cam_name, img in obs["pixels"].items():
|
||||
@@ -76,13 +76,13 @@ def preprocess_observation(obs, device='cuda'):
|
||||
if img.shape[-2:] != (224, 224):
|
||||
img = F.pad(img, (0, 224 - img.shape[-1], 0, 224 - img.shape[-2]))
|
||||
processed_obs[f"observation.images.{cam_name}"] = img
|
||||
|
||||
|
||||
# Process state
|
||||
if "agent_pos" in obs:
|
||||
state = torch.from_numpy(obs["agent_pos"]).float()
|
||||
state = state.unsqueeze(0).to(device)
|
||||
processed_obs["observation.state"] = state
|
||||
|
||||
|
||||
return processed_obs
|
||||
```
|
||||
|
||||
@@ -180,7 +180,7 @@ print("Original keys:", observation.keys())
|
||||
print("Processed keys:", processed_obs.keys())
|
||||
print("Image shape:", processed_obs["observation.images.camera_front"].shape) # [1, 3, 480, 640]
|
||||
print("Image dtype:", processed_obs["observation.images.camera_front"].dtype) # torch.float32
|
||||
print("Image range:", processed_obs["observation.images.camera_front"].min().item(),
|
||||
print("Image range:", processed_obs["observation.images.camera_front"].min().item(),
|
||||
"to", processed_obs["observation.images.camera_front"].max().item()) # 0.0 to 1.0
|
||||
```
|
||||
|
||||
@@ -197,18 +197,18 @@ import torch.nn.functional as F
|
||||
@dataclass
|
||||
class ImagePadder:
|
||||
"""Pad images to a standard size for batch processing."""
|
||||
|
||||
|
||||
target_height: int = 224
|
||||
target_width: int = 224
|
||||
pad_value: float = 0.0
|
||||
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Main processing method - required for all steps."""
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
|
||||
|
||||
if obs is None:
|
||||
return transition
|
||||
|
||||
|
||||
# Process all image observations
|
||||
for key in list(obs.keys()):
|
||||
if key.startswith("observation.images."):
|
||||
@@ -217,22 +217,22 @@ class ImagePadder:
|
||||
_, _, h, w = img.shape
|
||||
pad_h = max(0, self.target_height - h)
|
||||
pad_w = max(0, self.target_width - w)
|
||||
|
||||
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
# Pad symmetrically
|
||||
pad_left = pad_w // 2
|
||||
pad_right = pad_w - pad_left
|
||||
pad_top = pad_h // 2
|
||||
pad_bottom = pad_h - pad_top
|
||||
|
||||
img = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom),
|
||||
|
||||
img = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom),
|
||||
mode='constant', value=self.pad_value)
|
||||
|
||||
|
||||
obs[key] = img
|
||||
|
||||
|
||||
# Return modified transition
|
||||
return (obs, *transition[1:])
|
||||
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
"""Return JSON-serializable configuration - required for save/load."""
|
||||
return {
|
||||
@@ -240,17 +240,17 @@ class ImagePadder:
|
||||
"target_width": self.target_width,
|
||||
"pad_value": self.pad_value
|
||||
}
|
||||
|
||||
|
||||
def state_dict(self) -> Dict[str, torch.Tensor]:
|
||||
"""Return tensor state - only include torch.Tensor objects!"""
|
||||
# This step has no learnable parameters
|
||||
return {}
|
||||
|
||||
|
||||
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
|
||||
"""Load tensor state - required if state_dict returns non-empty dict."""
|
||||
# Nothing to load for this step
|
||||
pass
|
||||
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset internal state at episode boundaries - required for stateful steps."""
|
||||
# This step is stateless, so nothing to reset
|
||||
@@ -265,15 +265,15 @@ These two methods serve different purposes and it's crucial to use them correctl
|
||||
@dataclass
|
||||
class AdaptiveNormalizer:
|
||||
"""Example showing proper use of get_config and state_dict."""
|
||||
|
||||
|
||||
learning_rate: float = 0.01
|
||||
epsilon: float = 1e-8
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
self.running_mean = None
|
||||
self.running_var = None
|
||||
self.num_samples = 0 # Python int, not tensor
|
||||
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
"""ONLY Python objects that can be JSON serialized!"""
|
||||
return {
|
||||
@@ -282,7 +282,7 @@ class AdaptiveNormalizer:
|
||||
"num_samples": self.num_samples, # int ✓
|
||||
# "running_mean": self.running_mean, # torch.Tensor ✗ WRONG!
|
||||
}
|
||||
|
||||
|
||||
def state_dict(self) -> Dict[str, torch.Tensor]:
|
||||
"""ONLY torch.Tensor objects!"""
|
||||
if self.running_mean is None:
|
||||
@@ -294,7 +294,7 @@ class AdaptiveNormalizer:
|
||||
# Instead, convert to tensor if needed:
|
||||
"num_samples_tensor": torch.tensor(self.num_samples)
|
||||
}
|
||||
|
||||
|
||||
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
|
||||
"""Load tensors and convert back to Python types if needed."""
|
||||
self.running_mean = state.get("running_mean")
|
||||
@@ -311,14 +311,14 @@ The `complementary_data` field is perfect for passing information between steps
|
||||
@dataclass
|
||||
class ImageStatisticsCalculator:
|
||||
"""Calculate image statistics and pass to next steps."""
|
||||
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {}
|
||||
|
||||
|
||||
if obs is None:
|
||||
return transition
|
||||
|
||||
|
||||
# Calculate statistics for all images
|
||||
image_stats = {}
|
||||
for key in obs:
|
||||
@@ -331,10 +331,10 @@ class ImageStatisticsCalculator:
|
||||
"max": img.max().item(),
|
||||
}
|
||||
image_stats[key] = stats
|
||||
|
||||
|
||||
# Store in complementary_data for next steps
|
||||
comp_data["image_statistics"] = image_stats
|
||||
|
||||
|
||||
# Return transition with updated complementary_data
|
||||
return (
|
||||
obs,
|
||||
@@ -346,30 +346,30 @@ class ImageStatisticsCalculator:
|
||||
comp_data # Updated complementary_data
|
||||
)
|
||||
|
||||
@dataclass
|
||||
@dataclass
|
||||
class AdaptiveBrightnessAdjuster:
|
||||
"""Adjust brightness based on statistics from previous step."""
|
||||
|
||||
|
||||
target_brightness: float = 0.5
|
||||
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {}
|
||||
|
||||
|
||||
if obs is None or "image_statistics" not in comp_data:
|
||||
return transition
|
||||
|
||||
|
||||
# Use statistics from previous step
|
||||
image_stats = comp_data["image_statistics"]
|
||||
|
||||
|
||||
for key in obs:
|
||||
if key.startswith("observation.images.") and key in image_stats:
|
||||
current_mean = image_stats[key]["mean"]
|
||||
brightness_adjust = self.target_brightness - current_mean
|
||||
|
||||
|
||||
# Adjust brightness
|
||||
obs[key] = torch.clamp(obs[key] + brightness_adjust, 0, 1)
|
||||
|
||||
|
||||
return (obs, *transition[1:])
|
||||
|
||||
# Use them together
|
||||
@@ -394,28 +394,28 @@ import numpy as np
|
||||
@dataclass
|
||||
class DeviceMover:
|
||||
"""Move all tensors to specified device."""
|
||||
|
||||
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
|
||||
|
||||
if obs is None:
|
||||
return transition
|
||||
|
||||
|
||||
# Move all tensor observations to device
|
||||
for key, value in obs.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
obs[key] = value.to(self.device)
|
||||
|
||||
|
||||
# Also handle action if present
|
||||
action = transition[TransitionIndex.ACTION]
|
||||
if action is not None and isinstance(action, torch.Tensor):
|
||||
action = action.to(self.device)
|
||||
return (obs, action, *transition[2:])
|
||||
|
||||
|
||||
return (obs, *transition[1:])
|
||||
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
return {"device": str(self.device)}
|
||||
|
||||
@@ -423,52 +423,52 @@ class DeviceMover:
|
||||
@dataclass
|
||||
class RunningNormalizer:
|
||||
"""Normalize using running statistics with proper device handling."""
|
||||
|
||||
|
||||
feature_dim: int
|
||||
momentum: float = 0.1
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
# Initialize as None - will be created on first call with correct device
|
||||
self.running_mean = None
|
||||
self.running_var = None
|
||||
self.initialized = False
|
||||
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
|
||||
|
||||
if obs is None or "observation.state" not in obs:
|
||||
return transition
|
||||
|
||||
|
||||
state = obs["observation.state"]
|
||||
|
||||
|
||||
# Initialize on first call with correct device
|
||||
if not self.initialized:
|
||||
device = state.device
|
||||
self.running_mean = torch.zeros(self.feature_dim, device=device)
|
||||
self.running_var = torch.ones(self.feature_dim, device=device)
|
||||
self.initialized = True
|
||||
|
||||
|
||||
# Update statistics
|
||||
with torch.no_grad():
|
||||
batch_mean = state.mean(dim=0)
|
||||
batch_var = state.var(dim=0, unbiased=False)
|
||||
|
||||
|
||||
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
|
||||
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
|
||||
|
||||
|
||||
# Normalize
|
||||
state_normalized = (state - self.running_mean) / (self.running_var + 1e-8).sqrt()
|
||||
obs["observation.state"] = state_normalized
|
||||
|
||||
|
||||
return (obs, *transition[1:])
|
||||
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"feature_dim": self.feature_dim,
|
||||
"momentum": self.momentum,
|
||||
"initialized": self.initialized
|
||||
}
|
||||
|
||||
|
||||
def state_dict(self) -> Dict[str, torch.Tensor]:
|
||||
if not self.initialized:
|
||||
return {}
|
||||
@@ -476,13 +476,13 @@ class RunningNormalizer:
|
||||
"running_mean": self.running_mean,
|
||||
"running_var": self.running_var
|
||||
}
|
||||
|
||||
|
||||
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
|
||||
if state:
|
||||
self.running_mean = state["running_mean"]
|
||||
self.running_var = state["running_var"]
|
||||
self.initialized = True
|
||||
|
||||
|
||||
def reset(self) -> None:
|
||||
# Don't reset statistics - they persist across episodes
|
||||
pass
|
||||
@@ -524,29 +524,29 @@ When your environment and policy speak different "languages":
|
||||
@dataclass
|
||||
class KeyRemapper:
|
||||
"""Rename observation keys to match policy expectations."""
|
||||
|
||||
|
||||
key_mapping: Dict[str, str] = field(default_factory=lambda: {
|
||||
"rgb_camera_front": "observation.images.wrist",
|
||||
"joint_positions": "observation.state",
|
||||
"gripper_state": "observation.gripper"
|
||||
})
|
||||
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
if obs is None:
|
||||
return transition
|
||||
|
||||
|
||||
# Create new observation with renamed keys
|
||||
renamed_obs = {}
|
||||
for old_key, new_key in self.key_mapping.items():
|
||||
if old_key in obs:
|
||||
renamed_obs[new_key] = obs[old_key]
|
||||
|
||||
|
||||
# Keep any unmapped keys as-is
|
||||
for key, value in obs.items():
|
||||
if key not in self.key_mapping:
|
||||
renamed_obs[key] = value
|
||||
|
||||
|
||||
return (renamed_obs, *transition[1:])
|
||||
```
|
||||
|
||||
@@ -559,15 +559,15 @@ When you need to crop and resize images to focus on the manipulation workspace:
|
||||
@dataclass
|
||||
class WorkspaceCropper:
|
||||
"""Crop and resize images to focus on robot workspace."""
|
||||
|
||||
|
||||
crop_bbox: Tuple[int, int, int, int] = (400, 200, 1200, 800) # (x1, y1, x2, y2)
|
||||
output_size: Tuple[int, int] = (224, 224)
|
||||
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
if obs is None:
|
||||
return transition
|
||||
|
||||
|
||||
for key in list(obs.keys()):
|
||||
if key.startswith("observation.images."):
|
||||
img = obs[key]
|
||||
@@ -576,13 +576,13 @@ class WorkspaceCropper:
|
||||
img_cropped = img[:, :, y1:y2, x1:x2]
|
||||
# Resize to expected dimensions
|
||||
img_resized = F.interpolate(
|
||||
img_cropped,
|
||||
size=self.output_size,
|
||||
mode='bilinear',
|
||||
img_cropped,
|
||||
size=self.output_size,
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
obs[key] = img_resized
|
||||
|
||||
|
||||
return (obs, *transition[1:])
|
||||
```
|
||||
|
||||
@@ -596,7 +596,7 @@ Now you can compose these steps into robot-specific pipelines:
|
||||
robot_a_processor = RobotProcessor([
|
||||
# Observation preprocessing
|
||||
KeyRemapper({
|
||||
"wrist_rgb": "observation.images.wrist",
|
||||
"wrist_rgb": "observation.images.wrist",
|
||||
"arm_joints": "observation.state"
|
||||
}),
|
||||
ImageProcessor(),
|
||||
@@ -604,7 +604,7 @@ robot_a_processor = RobotProcessor([
|
||||
StateProcessor(),
|
||||
VelocityCalculator(state_key="observation.state"),
|
||||
DeviceMover("cuda"),
|
||||
|
||||
|
||||
# Action postprocessing
|
||||
ActionSmoother(alpha=0.2),
|
||||
], name="RobotA_ACT_Processor")
|
||||
@@ -618,10 +618,10 @@ robot_b_processor = RobotProcessor([
|
||||
}),
|
||||
ImageProcessor(),
|
||||
WorkspaceCropper(crop_bbox=(100, 50, 1100, 950), output_size=(224, 224)),
|
||||
StateProcessor(),
|
||||
StateProcessor(),
|
||||
VelocityCalculator(state_key="observation.state"),
|
||||
DeviceMover("cuda"),
|
||||
|
||||
|
||||
ActionSmoother(alpha=0.3), # Different smoothing
|
||||
], name="RobotB_ACT_Processor")
|
||||
|
||||
@@ -645,11 +645,11 @@ The beauty of this approach is that:
|
||||
```python
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
|
||||
|
||||
# Always check if observation exists
|
||||
if obs is None:
|
||||
return transition
|
||||
|
||||
|
||||
# Also check for specific keys
|
||||
if "observation.state" not in obs:
|
||||
return transition
|
||||
@@ -668,11 +668,11 @@ return (modified_obs, None, 0.0, False, False, {}, {})
|
||||
```python
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition[TransitionIndex.OBSERVATION]
|
||||
|
||||
|
||||
if self.store_previous:
|
||||
# Good - clone to avoid reference issues
|
||||
self.previous_state = obs["observation.state"].clone()
|
||||
|
||||
|
||||
# Bad - stores reference that might be modified
|
||||
# self.previous_state = obs["observation.state"]
|
||||
```
|
||||
@@ -682,7 +682,7 @@ def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
def state_dict(self) -> Dict[str, torch.Tensor]:
|
||||
if self.buffer is None:
|
||||
return {}
|
||||
|
||||
|
||||
# Good - clone to avoid memory sharing issues
|
||||
return {"buffer": self.buffer.clone()}
|
||||
```
|
||||
@@ -713,16 +713,16 @@ class ActionClipper:
|
||||
"""Clip actions to safe ranges."""
|
||||
min_value: float = -1.0
|
||||
max_value: float = 1.0
|
||||
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition[TransitionIndex.ACTION]
|
||||
|
||||
|
||||
if action is not None:
|
||||
action = torch.clamp(action, self.min_value, self.max_value)
|
||||
return (transition[TransitionIndex.OBSERVATION], action, *transition[2:])
|
||||
|
||||
|
||||
return transition
|
||||
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
return {"min_value": self.min_value, "max_value": self.max_value}
|
||||
|
||||
@@ -736,7 +736,7 @@ policy = ACTPolicy.from_pretrained("lerobot/act_aloha_sim_transfer_cube_human")
|
||||
|
||||
# Move everything to GPU
|
||||
preprocessor = preprocessor.to("cuda")
|
||||
postprocessor = postprocessor.to("cuda")
|
||||
postprocessor = postprocessor.to("cuda")
|
||||
policy = policy.to("cuda")
|
||||
|
||||
# Control loop
|
||||
@@ -745,42 +745,42 @@ obs, info = env.reset()
|
||||
|
||||
for episode in range(10):
|
||||
print(f"Episode {episode + 1}")
|
||||
|
||||
|
||||
for step in range(1000):
|
||||
# Create transition with raw observation
|
||||
transition = (obs, None, 0.0, False, False, info, {"step": step})
|
||||
|
||||
|
||||
# Preprocess
|
||||
processed_transition = preprocessor(transition)
|
||||
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
|
||||
|
||||
|
||||
# Get action from policy
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(processed_obs)
|
||||
|
||||
|
||||
# Postprocess action
|
||||
action_transition = (
|
||||
processed_obs,
|
||||
action,
|
||||
0.0,
|
||||
False,
|
||||
False,
|
||||
info,
|
||||
processed_obs,
|
||||
action,
|
||||
0.0,
|
||||
False,
|
||||
False,
|
||||
info,
|
||||
{"raw_action": action.clone()} # Store raw action in complementary_data
|
||||
)
|
||||
processed_action_transition = postprocessor(action_transition)
|
||||
final_action = processed_action_transition[TransitionIndex.ACTION]
|
||||
|
||||
|
||||
# Execute action
|
||||
obs, reward, terminated, truncated, info = env.step(final_action.cpu().numpy())
|
||||
|
||||
|
||||
if terminated or truncated:
|
||||
# Reset at episode boundary
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
obs, info = env.reset()
|
||||
break
|
||||
|
||||
|
||||
# Save preprocessor with learned statistics
|
||||
preprocessor.save_pretrained(f"./checkpoints/preprocessor_ep{episode}")
|
||||
|
||||
@@ -844,4 +844,4 @@ RobotProcessor provides a powerful, modular approach to data preprocessing in ro
|
||||
|
||||
By following these patterns, your preprocessing code becomes more maintainable, shareable, and robust.
|
||||
|
||||
For the full API reference, see the [RobotProcessor API documentation](/api/processor).
|
||||
For the full API reference, see the [RobotProcessor API documentation](/api/processor).
|
||||
|
||||
Reference in New Issue
Block a user