mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Adil Zouitine
parent
cd13f1ecfd
commit
ac742c9f0d
@@ -60,7 +60,7 @@ If you've worked with robot learning before, you've likely written preprocessing
|
|||||||
# Traditional procedural approach - hard to maintain and share
|
# Traditional procedural approach - hard to maintain and share
|
||||||
def preprocess_observation(obs, device='cuda'):
|
def preprocess_observation(obs, device='cuda'):
|
||||||
processed_obs = {}
|
processed_obs = {}
|
||||||
|
|
||||||
# Process images
|
# Process images
|
||||||
if "pixels" in obs:
|
if "pixels" in obs:
|
||||||
for cam_name, img in obs["pixels"].items():
|
for cam_name, img in obs["pixels"].items():
|
||||||
@@ -76,13 +76,13 @@ def preprocess_observation(obs, device='cuda'):
|
|||||||
if img.shape[-2:] != (224, 224):
|
if img.shape[-2:] != (224, 224):
|
||||||
img = F.pad(img, (0, 224 - img.shape[-1], 0, 224 - img.shape[-2]))
|
img = F.pad(img, (0, 224 - img.shape[-1], 0, 224 - img.shape[-2]))
|
||||||
processed_obs[f"observation.images.{cam_name}"] = img
|
processed_obs[f"observation.images.{cam_name}"] = img
|
||||||
|
|
||||||
# Process state
|
# Process state
|
||||||
if "agent_pos" in obs:
|
if "agent_pos" in obs:
|
||||||
state = torch.from_numpy(obs["agent_pos"]).float()
|
state = torch.from_numpy(obs["agent_pos"]).float()
|
||||||
state = state.unsqueeze(0).to(device)
|
state = state.unsqueeze(0).to(device)
|
||||||
processed_obs["observation.state"] = state
|
processed_obs["observation.state"] = state
|
||||||
|
|
||||||
return processed_obs
|
return processed_obs
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -180,7 +180,7 @@ print("Original keys:", observation.keys())
|
|||||||
print("Processed keys:", processed_obs.keys())
|
print("Processed keys:", processed_obs.keys())
|
||||||
print("Image shape:", processed_obs["observation.images.camera_front"].shape) # [1, 3, 480, 640]
|
print("Image shape:", processed_obs["observation.images.camera_front"].shape) # [1, 3, 480, 640]
|
||||||
print("Image dtype:", processed_obs["observation.images.camera_front"].dtype) # torch.float32
|
print("Image dtype:", processed_obs["observation.images.camera_front"].dtype) # torch.float32
|
||||||
print("Image range:", processed_obs["observation.images.camera_front"].min().item(),
|
print("Image range:", processed_obs["observation.images.camera_front"].min().item(),
|
||||||
"to", processed_obs["observation.images.camera_front"].max().item()) # 0.0 to 1.0
|
"to", processed_obs["observation.images.camera_front"].max().item()) # 0.0 to 1.0
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -197,18 +197,18 @@ import torch.nn.functional as F
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ImagePadder:
|
class ImagePadder:
|
||||||
"""Pad images to a standard size for batch processing."""
|
"""Pad images to a standard size for batch processing."""
|
||||||
|
|
||||||
target_height: int = 224
|
target_height: int = 224
|
||||||
target_width: int = 224
|
target_width: int = 224
|
||||||
pad_value: float = 0.0
|
pad_value: float = 0.0
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
"""Main processing method - required for all steps."""
|
"""Main processing method - required for all steps."""
|
||||||
obs = transition[TransitionIndex.OBSERVATION]
|
obs = transition[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
if obs is None:
|
if obs is None:
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
# Process all image observations
|
# Process all image observations
|
||||||
for key in list(obs.keys()):
|
for key in list(obs.keys()):
|
||||||
if key.startswith("observation.images."):
|
if key.startswith("observation.images."):
|
||||||
@@ -217,22 +217,22 @@ class ImagePadder:
|
|||||||
_, _, h, w = img.shape
|
_, _, h, w = img.shape
|
||||||
pad_h = max(0, self.target_height - h)
|
pad_h = max(0, self.target_height - h)
|
||||||
pad_w = max(0, self.target_width - w)
|
pad_w = max(0, self.target_width - w)
|
||||||
|
|
||||||
if pad_h > 0 or pad_w > 0:
|
if pad_h > 0 or pad_w > 0:
|
||||||
# Pad symmetrically
|
# Pad symmetrically
|
||||||
pad_left = pad_w // 2
|
pad_left = pad_w // 2
|
||||||
pad_right = pad_w - pad_left
|
pad_right = pad_w - pad_left
|
||||||
pad_top = pad_h // 2
|
pad_top = pad_h // 2
|
||||||
pad_bottom = pad_h - pad_top
|
pad_bottom = pad_h - pad_top
|
||||||
|
|
||||||
img = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom),
|
img = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom),
|
||||||
mode='constant', value=self.pad_value)
|
mode='constant', value=self.pad_value)
|
||||||
|
|
||||||
obs[key] = img
|
obs[key] = img
|
||||||
|
|
||||||
# Return modified transition
|
# Return modified transition
|
||||||
return (obs, *transition[1:])
|
return (obs, *transition[1:])
|
||||||
|
|
||||||
def get_config(self) -> Dict[str, Any]:
|
def get_config(self) -> Dict[str, Any]:
|
||||||
"""Return JSON-serializable configuration - required for save/load."""
|
"""Return JSON-serializable configuration - required for save/load."""
|
||||||
return {
|
return {
|
||||||
@@ -240,17 +240,17 @@ class ImagePadder:
|
|||||||
"target_width": self.target_width,
|
"target_width": self.target_width,
|
||||||
"pad_value": self.pad_value
|
"pad_value": self.pad_value
|
||||||
}
|
}
|
||||||
|
|
||||||
def state_dict(self) -> Dict[str, torch.Tensor]:
|
def state_dict(self) -> Dict[str, torch.Tensor]:
|
||||||
"""Return tensor state - only include torch.Tensor objects!"""
|
"""Return tensor state - only include torch.Tensor objects!"""
|
||||||
# This step has no learnable parameters
|
# This step has no learnable parameters
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
|
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
|
||||||
"""Load tensor state - required if state_dict returns non-empty dict."""
|
"""Load tensor state - required if state_dict returns non-empty dict."""
|
||||||
# Nothing to load for this step
|
# Nothing to load for this step
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Reset internal state at episode boundaries - required for stateful steps."""
|
"""Reset internal state at episode boundaries - required for stateful steps."""
|
||||||
# This step is stateless, so nothing to reset
|
# This step is stateless, so nothing to reset
|
||||||
@@ -265,15 +265,15 @@ These two methods serve different purposes and it's crucial to use them correctl
|
|||||||
@dataclass
|
@dataclass
|
||||||
class AdaptiveNormalizer:
|
class AdaptiveNormalizer:
|
||||||
"""Example showing proper use of get_config and state_dict."""
|
"""Example showing proper use of get_config and state_dict."""
|
||||||
|
|
||||||
learning_rate: float = 0.01
|
learning_rate: float = 0.01
|
||||||
epsilon: float = 1e-8
|
epsilon: float = 1e-8
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.running_mean = None
|
self.running_mean = None
|
||||||
self.running_var = None
|
self.running_var = None
|
||||||
self.num_samples = 0 # Python int, not tensor
|
self.num_samples = 0 # Python int, not tensor
|
||||||
|
|
||||||
def get_config(self) -> Dict[str, Any]:
|
def get_config(self) -> Dict[str, Any]:
|
||||||
"""ONLY Python objects that can be JSON serialized!"""
|
"""ONLY Python objects that can be JSON serialized!"""
|
||||||
return {
|
return {
|
||||||
@@ -282,7 +282,7 @@ class AdaptiveNormalizer:
|
|||||||
"num_samples": self.num_samples, # int ✓
|
"num_samples": self.num_samples, # int ✓
|
||||||
# "running_mean": self.running_mean, # torch.Tensor ✗ WRONG!
|
# "running_mean": self.running_mean, # torch.Tensor ✗ WRONG!
|
||||||
}
|
}
|
||||||
|
|
||||||
def state_dict(self) -> Dict[str, torch.Tensor]:
|
def state_dict(self) -> Dict[str, torch.Tensor]:
|
||||||
"""ONLY torch.Tensor objects!"""
|
"""ONLY torch.Tensor objects!"""
|
||||||
if self.running_mean is None:
|
if self.running_mean is None:
|
||||||
@@ -294,7 +294,7 @@ class AdaptiveNormalizer:
|
|||||||
# Instead, convert to tensor if needed:
|
# Instead, convert to tensor if needed:
|
||||||
"num_samples_tensor": torch.tensor(self.num_samples)
|
"num_samples_tensor": torch.tensor(self.num_samples)
|
||||||
}
|
}
|
||||||
|
|
||||||
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
|
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
|
||||||
"""Load tensors and convert back to Python types if needed."""
|
"""Load tensors and convert back to Python types if needed."""
|
||||||
self.running_mean = state.get("running_mean")
|
self.running_mean = state.get("running_mean")
|
||||||
@@ -311,14 +311,14 @@ The `complementary_data` field is perfect for passing information between steps
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ImageStatisticsCalculator:
|
class ImageStatisticsCalculator:
|
||||||
"""Calculate image statistics and pass to next steps."""
|
"""Calculate image statistics and pass to next steps."""
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
obs = transition[TransitionIndex.OBSERVATION]
|
obs = transition[TransitionIndex.OBSERVATION]
|
||||||
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {}
|
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {}
|
||||||
|
|
||||||
if obs is None:
|
if obs is None:
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
# Calculate statistics for all images
|
# Calculate statistics for all images
|
||||||
image_stats = {}
|
image_stats = {}
|
||||||
for key in obs:
|
for key in obs:
|
||||||
@@ -331,10 +331,10 @@ class ImageStatisticsCalculator:
|
|||||||
"max": img.max().item(),
|
"max": img.max().item(),
|
||||||
}
|
}
|
||||||
image_stats[key] = stats
|
image_stats[key] = stats
|
||||||
|
|
||||||
# Store in complementary_data for next steps
|
# Store in complementary_data for next steps
|
||||||
comp_data["image_statistics"] = image_stats
|
comp_data["image_statistics"] = image_stats
|
||||||
|
|
||||||
# Return transition with updated complementary_data
|
# Return transition with updated complementary_data
|
||||||
return (
|
return (
|
||||||
obs,
|
obs,
|
||||||
@@ -346,30 +346,30 @@ class ImageStatisticsCalculator:
|
|||||||
comp_data # Updated complementary_data
|
comp_data # Updated complementary_data
|
||||||
)
|
)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AdaptiveBrightnessAdjuster:
|
class AdaptiveBrightnessAdjuster:
|
||||||
"""Adjust brightness based on statistics from previous step."""
|
"""Adjust brightness based on statistics from previous step."""
|
||||||
|
|
||||||
target_brightness: float = 0.5
|
target_brightness: float = 0.5
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
obs = transition[TransitionIndex.OBSERVATION]
|
obs = transition[TransitionIndex.OBSERVATION]
|
||||||
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {}
|
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {}
|
||||||
|
|
||||||
if obs is None or "image_statistics" not in comp_data:
|
if obs is None or "image_statistics" not in comp_data:
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
# Use statistics from previous step
|
# Use statistics from previous step
|
||||||
image_stats = comp_data["image_statistics"]
|
image_stats = comp_data["image_statistics"]
|
||||||
|
|
||||||
for key in obs:
|
for key in obs:
|
||||||
if key.startswith("observation.images.") and key in image_stats:
|
if key.startswith("observation.images.") and key in image_stats:
|
||||||
current_mean = image_stats[key]["mean"]
|
current_mean = image_stats[key]["mean"]
|
||||||
brightness_adjust = self.target_brightness - current_mean
|
brightness_adjust = self.target_brightness - current_mean
|
||||||
|
|
||||||
# Adjust brightness
|
# Adjust brightness
|
||||||
obs[key] = torch.clamp(obs[key] + brightness_adjust, 0, 1)
|
obs[key] = torch.clamp(obs[key] + brightness_adjust, 0, 1)
|
||||||
|
|
||||||
return (obs, *transition[1:])
|
return (obs, *transition[1:])
|
||||||
|
|
||||||
# Use them together
|
# Use them together
|
||||||
@@ -394,28 +394,28 @@ import numpy as np
|
|||||||
@dataclass
|
@dataclass
|
||||||
class DeviceMover:
|
class DeviceMover:
|
||||||
"""Move all tensors to specified device."""
|
"""Move all tensors to specified device."""
|
||||||
|
|
||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
obs = transition[TransitionIndex.OBSERVATION]
|
obs = transition[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
if obs is None:
|
if obs is None:
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
# Move all tensor observations to device
|
# Move all tensor observations to device
|
||||||
for key, value in obs.items():
|
for key, value in obs.items():
|
||||||
if isinstance(value, torch.Tensor):
|
if isinstance(value, torch.Tensor):
|
||||||
obs[key] = value.to(self.device)
|
obs[key] = value.to(self.device)
|
||||||
|
|
||||||
# Also handle action if present
|
# Also handle action if present
|
||||||
action = transition[TransitionIndex.ACTION]
|
action = transition[TransitionIndex.ACTION]
|
||||||
if action is not None and isinstance(action, torch.Tensor):
|
if action is not None and isinstance(action, torch.Tensor):
|
||||||
action = action.to(self.device)
|
action = action.to(self.device)
|
||||||
return (obs, action, *transition[2:])
|
return (obs, action, *transition[2:])
|
||||||
|
|
||||||
return (obs, *transition[1:])
|
return (obs, *transition[1:])
|
||||||
|
|
||||||
def get_config(self) -> Dict[str, Any]:
|
def get_config(self) -> Dict[str, Any]:
|
||||||
return {"device": str(self.device)}
|
return {"device": str(self.device)}
|
||||||
|
|
||||||
@@ -423,52 +423,52 @@ class DeviceMover:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class RunningNormalizer:
|
class RunningNormalizer:
|
||||||
"""Normalize using running statistics with proper device handling."""
|
"""Normalize using running statistics with proper device handling."""
|
||||||
|
|
||||||
feature_dim: int
|
feature_dim: int
|
||||||
momentum: float = 0.1
|
momentum: float = 0.1
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Initialize as None - will be created on first call with correct device
|
# Initialize as None - will be created on first call with correct device
|
||||||
self.running_mean = None
|
self.running_mean = None
|
||||||
self.running_var = None
|
self.running_var = None
|
||||||
self.initialized = False
|
self.initialized = False
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
obs = transition[TransitionIndex.OBSERVATION]
|
obs = transition[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
if obs is None or "observation.state" not in obs:
|
if obs is None or "observation.state" not in obs:
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
state = obs["observation.state"]
|
state = obs["observation.state"]
|
||||||
|
|
||||||
# Initialize on first call with correct device
|
# Initialize on first call with correct device
|
||||||
if not self.initialized:
|
if not self.initialized:
|
||||||
device = state.device
|
device = state.device
|
||||||
self.running_mean = torch.zeros(self.feature_dim, device=device)
|
self.running_mean = torch.zeros(self.feature_dim, device=device)
|
||||||
self.running_var = torch.ones(self.feature_dim, device=device)
|
self.running_var = torch.ones(self.feature_dim, device=device)
|
||||||
self.initialized = True
|
self.initialized = True
|
||||||
|
|
||||||
# Update statistics
|
# Update statistics
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
batch_mean = state.mean(dim=0)
|
batch_mean = state.mean(dim=0)
|
||||||
batch_var = state.var(dim=0, unbiased=False)
|
batch_var = state.var(dim=0, unbiased=False)
|
||||||
|
|
||||||
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
|
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
|
||||||
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
|
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
|
||||||
|
|
||||||
# Normalize
|
# Normalize
|
||||||
state_normalized = (state - self.running_mean) / (self.running_var + 1e-8).sqrt()
|
state_normalized = (state - self.running_mean) / (self.running_var + 1e-8).sqrt()
|
||||||
obs["observation.state"] = state_normalized
|
obs["observation.state"] = state_normalized
|
||||||
|
|
||||||
return (obs, *transition[1:])
|
return (obs, *transition[1:])
|
||||||
|
|
||||||
def get_config(self) -> Dict[str, Any]:
|
def get_config(self) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"feature_dim": self.feature_dim,
|
"feature_dim": self.feature_dim,
|
||||||
"momentum": self.momentum,
|
"momentum": self.momentum,
|
||||||
"initialized": self.initialized
|
"initialized": self.initialized
|
||||||
}
|
}
|
||||||
|
|
||||||
def state_dict(self) -> Dict[str, torch.Tensor]:
|
def state_dict(self) -> Dict[str, torch.Tensor]:
|
||||||
if not self.initialized:
|
if not self.initialized:
|
||||||
return {}
|
return {}
|
||||||
@@ -476,13 +476,13 @@ class RunningNormalizer:
|
|||||||
"running_mean": self.running_mean,
|
"running_mean": self.running_mean,
|
||||||
"running_var": self.running_var
|
"running_var": self.running_var
|
||||||
}
|
}
|
||||||
|
|
||||||
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
|
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
|
||||||
if state:
|
if state:
|
||||||
self.running_mean = state["running_mean"]
|
self.running_mean = state["running_mean"]
|
||||||
self.running_var = state["running_var"]
|
self.running_var = state["running_var"]
|
||||||
self.initialized = True
|
self.initialized = True
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
# Don't reset statistics - they persist across episodes
|
# Don't reset statistics - they persist across episodes
|
||||||
pass
|
pass
|
||||||
@@ -524,29 +524,29 @@ When your environment and policy speak different "languages":
|
|||||||
@dataclass
|
@dataclass
|
||||||
class KeyRemapper:
|
class KeyRemapper:
|
||||||
"""Rename observation keys to match policy expectations."""
|
"""Rename observation keys to match policy expectations."""
|
||||||
|
|
||||||
key_mapping: Dict[str, str] = field(default_factory=lambda: {
|
key_mapping: Dict[str, str] = field(default_factory=lambda: {
|
||||||
"rgb_camera_front": "observation.images.wrist",
|
"rgb_camera_front": "observation.images.wrist",
|
||||||
"joint_positions": "observation.state",
|
"joint_positions": "observation.state",
|
||||||
"gripper_state": "observation.gripper"
|
"gripper_state": "observation.gripper"
|
||||||
})
|
})
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
obs = transition[TransitionIndex.OBSERVATION]
|
obs = transition[TransitionIndex.OBSERVATION]
|
||||||
if obs is None:
|
if obs is None:
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
# Create new observation with renamed keys
|
# Create new observation with renamed keys
|
||||||
renamed_obs = {}
|
renamed_obs = {}
|
||||||
for old_key, new_key in self.key_mapping.items():
|
for old_key, new_key in self.key_mapping.items():
|
||||||
if old_key in obs:
|
if old_key in obs:
|
||||||
renamed_obs[new_key] = obs[old_key]
|
renamed_obs[new_key] = obs[old_key]
|
||||||
|
|
||||||
# Keep any unmapped keys as-is
|
# Keep any unmapped keys as-is
|
||||||
for key, value in obs.items():
|
for key, value in obs.items():
|
||||||
if key not in self.key_mapping:
|
if key not in self.key_mapping:
|
||||||
renamed_obs[key] = value
|
renamed_obs[key] = value
|
||||||
|
|
||||||
return (renamed_obs, *transition[1:])
|
return (renamed_obs, *transition[1:])
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -559,15 +559,15 @@ When you need to crop and resize images to focus on the manipulation workspace:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class WorkspaceCropper:
|
class WorkspaceCropper:
|
||||||
"""Crop and resize images to focus on robot workspace."""
|
"""Crop and resize images to focus on robot workspace."""
|
||||||
|
|
||||||
crop_bbox: Tuple[int, int, int, int] = (400, 200, 1200, 800) # (x1, y1, x2, y2)
|
crop_bbox: Tuple[int, int, int, int] = (400, 200, 1200, 800) # (x1, y1, x2, y2)
|
||||||
output_size: Tuple[int, int] = (224, 224)
|
output_size: Tuple[int, int] = (224, 224)
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
obs = transition[TransitionIndex.OBSERVATION]
|
obs = transition[TransitionIndex.OBSERVATION]
|
||||||
if obs is None:
|
if obs is None:
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
for key in list(obs.keys()):
|
for key in list(obs.keys()):
|
||||||
if key.startswith("observation.images."):
|
if key.startswith("observation.images."):
|
||||||
img = obs[key]
|
img = obs[key]
|
||||||
@@ -576,13 +576,13 @@ class WorkspaceCropper:
|
|||||||
img_cropped = img[:, :, y1:y2, x1:x2]
|
img_cropped = img[:, :, y1:y2, x1:x2]
|
||||||
# Resize to expected dimensions
|
# Resize to expected dimensions
|
||||||
img_resized = F.interpolate(
|
img_resized = F.interpolate(
|
||||||
img_cropped,
|
img_cropped,
|
||||||
size=self.output_size,
|
size=self.output_size,
|
||||||
mode='bilinear',
|
mode='bilinear',
|
||||||
align_corners=False
|
align_corners=False
|
||||||
)
|
)
|
||||||
obs[key] = img_resized
|
obs[key] = img_resized
|
||||||
|
|
||||||
return (obs, *transition[1:])
|
return (obs, *transition[1:])
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -596,7 +596,7 @@ Now you can compose these steps into robot-specific pipelines:
|
|||||||
robot_a_processor = RobotProcessor([
|
robot_a_processor = RobotProcessor([
|
||||||
# Observation preprocessing
|
# Observation preprocessing
|
||||||
KeyRemapper({
|
KeyRemapper({
|
||||||
"wrist_rgb": "observation.images.wrist",
|
"wrist_rgb": "observation.images.wrist",
|
||||||
"arm_joints": "observation.state"
|
"arm_joints": "observation.state"
|
||||||
}),
|
}),
|
||||||
ImageProcessor(),
|
ImageProcessor(),
|
||||||
@@ -604,7 +604,7 @@ robot_a_processor = RobotProcessor([
|
|||||||
StateProcessor(),
|
StateProcessor(),
|
||||||
VelocityCalculator(state_key="observation.state"),
|
VelocityCalculator(state_key="observation.state"),
|
||||||
DeviceMover("cuda"),
|
DeviceMover("cuda"),
|
||||||
|
|
||||||
# Action postprocessing
|
# Action postprocessing
|
||||||
ActionSmoother(alpha=0.2),
|
ActionSmoother(alpha=0.2),
|
||||||
], name="RobotA_ACT_Processor")
|
], name="RobotA_ACT_Processor")
|
||||||
@@ -618,10 +618,10 @@ robot_b_processor = RobotProcessor([
|
|||||||
}),
|
}),
|
||||||
ImageProcessor(),
|
ImageProcessor(),
|
||||||
WorkspaceCropper(crop_bbox=(100, 50, 1100, 950), output_size=(224, 224)),
|
WorkspaceCropper(crop_bbox=(100, 50, 1100, 950), output_size=(224, 224)),
|
||||||
StateProcessor(),
|
StateProcessor(),
|
||||||
VelocityCalculator(state_key="observation.state"),
|
VelocityCalculator(state_key="observation.state"),
|
||||||
DeviceMover("cuda"),
|
DeviceMover("cuda"),
|
||||||
|
|
||||||
ActionSmoother(alpha=0.3), # Different smoothing
|
ActionSmoother(alpha=0.3), # Different smoothing
|
||||||
], name="RobotB_ACT_Processor")
|
], name="RobotB_ACT_Processor")
|
||||||
|
|
||||||
@@ -645,11 +645,11 @@ The beauty of this approach is that:
|
|||||||
```python
|
```python
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
obs = transition[TransitionIndex.OBSERVATION]
|
obs = transition[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
# Always check if observation exists
|
# Always check if observation exists
|
||||||
if obs is None:
|
if obs is None:
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
# Also check for specific keys
|
# Also check for specific keys
|
||||||
if "observation.state" not in obs:
|
if "observation.state" not in obs:
|
||||||
return transition
|
return transition
|
||||||
@@ -668,11 +668,11 @@ return (modified_obs, None, 0.0, False, False, {}, {})
|
|||||||
```python
|
```python
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
obs = transition[TransitionIndex.OBSERVATION]
|
obs = transition[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
if self.store_previous:
|
if self.store_previous:
|
||||||
# Good - clone to avoid reference issues
|
# Good - clone to avoid reference issues
|
||||||
self.previous_state = obs["observation.state"].clone()
|
self.previous_state = obs["observation.state"].clone()
|
||||||
|
|
||||||
# Bad - stores reference that might be modified
|
# Bad - stores reference that might be modified
|
||||||
# self.previous_state = obs["observation.state"]
|
# self.previous_state = obs["observation.state"]
|
||||||
```
|
```
|
||||||
@@ -682,7 +682,7 @@ def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|||||||
def state_dict(self) -> Dict[str, torch.Tensor]:
|
def state_dict(self) -> Dict[str, torch.Tensor]:
|
||||||
if self.buffer is None:
|
if self.buffer is None:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# Good - clone to avoid memory sharing issues
|
# Good - clone to avoid memory sharing issues
|
||||||
return {"buffer": self.buffer.clone()}
|
return {"buffer": self.buffer.clone()}
|
||||||
```
|
```
|
||||||
@@ -713,16 +713,16 @@ class ActionClipper:
|
|||||||
"""Clip actions to safe ranges."""
|
"""Clip actions to safe ranges."""
|
||||||
min_value: float = -1.0
|
min_value: float = -1.0
|
||||||
max_value: float = 1.0
|
max_value: float = 1.0
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
action = transition[TransitionIndex.ACTION]
|
action = transition[TransitionIndex.ACTION]
|
||||||
|
|
||||||
if action is not None:
|
if action is not None:
|
||||||
action = torch.clamp(action, self.min_value, self.max_value)
|
action = torch.clamp(action, self.min_value, self.max_value)
|
||||||
return (transition[TransitionIndex.OBSERVATION], action, *transition[2:])
|
return (transition[TransitionIndex.OBSERVATION], action, *transition[2:])
|
||||||
|
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
def get_config(self) -> Dict[str, Any]:
|
def get_config(self) -> Dict[str, Any]:
|
||||||
return {"min_value": self.min_value, "max_value": self.max_value}
|
return {"min_value": self.min_value, "max_value": self.max_value}
|
||||||
|
|
||||||
@@ -736,7 +736,7 @@ policy = ACTPolicy.from_pretrained("lerobot/act_aloha_sim_transfer_cube_human")
|
|||||||
|
|
||||||
# Move everything to GPU
|
# Move everything to GPU
|
||||||
preprocessor = preprocessor.to("cuda")
|
preprocessor = preprocessor.to("cuda")
|
||||||
postprocessor = postprocessor.to("cuda")
|
postprocessor = postprocessor.to("cuda")
|
||||||
policy = policy.to("cuda")
|
policy = policy.to("cuda")
|
||||||
|
|
||||||
# Control loop
|
# Control loop
|
||||||
@@ -745,42 +745,42 @@ obs, info = env.reset()
|
|||||||
|
|
||||||
for episode in range(10):
|
for episode in range(10):
|
||||||
print(f"Episode {episode + 1}")
|
print(f"Episode {episode + 1}")
|
||||||
|
|
||||||
for step in range(1000):
|
for step in range(1000):
|
||||||
# Create transition with raw observation
|
# Create transition with raw observation
|
||||||
transition = (obs, None, 0.0, False, False, info, {"step": step})
|
transition = (obs, None, 0.0, False, False, info, {"step": step})
|
||||||
|
|
||||||
# Preprocess
|
# Preprocess
|
||||||
processed_transition = preprocessor(transition)
|
processed_transition = preprocessor(transition)
|
||||||
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
|
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
|
||||||
|
|
||||||
# Get action from policy
|
# Get action from policy
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
action = policy.select_action(processed_obs)
|
action = policy.select_action(processed_obs)
|
||||||
|
|
||||||
# Postprocess action
|
# Postprocess action
|
||||||
action_transition = (
|
action_transition = (
|
||||||
processed_obs,
|
processed_obs,
|
||||||
action,
|
action,
|
||||||
0.0,
|
0.0,
|
||||||
False,
|
False,
|
||||||
False,
|
False,
|
||||||
info,
|
info,
|
||||||
{"raw_action": action.clone()} # Store raw action in complementary_data
|
{"raw_action": action.clone()} # Store raw action in complementary_data
|
||||||
)
|
)
|
||||||
processed_action_transition = postprocessor(action_transition)
|
processed_action_transition = postprocessor(action_transition)
|
||||||
final_action = processed_action_transition[TransitionIndex.ACTION]
|
final_action = processed_action_transition[TransitionIndex.ACTION]
|
||||||
|
|
||||||
# Execute action
|
# Execute action
|
||||||
obs, reward, terminated, truncated, info = env.step(final_action.cpu().numpy())
|
obs, reward, terminated, truncated, info = env.step(final_action.cpu().numpy())
|
||||||
|
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
# Reset at episode boundary
|
# Reset at episode boundary
|
||||||
preprocessor.reset()
|
preprocessor.reset()
|
||||||
postprocessor.reset()
|
postprocessor.reset()
|
||||||
obs, info = env.reset()
|
obs, info = env.reset()
|
||||||
break
|
break
|
||||||
|
|
||||||
# Save preprocessor with learned statistics
|
# Save preprocessor with learned statistics
|
||||||
preprocessor.save_pretrained(f"./checkpoints/preprocessor_ep{episode}")
|
preprocessor.save_pretrained(f"./checkpoints/preprocessor_ep{episode}")
|
||||||
|
|
||||||
@@ -844,4 +844,4 @@ RobotProcessor provides a powerful, modular approach to data preprocessing in ro
|
|||||||
|
|
||||||
By following these patterns, your preprocessing code becomes more maintainable, shareable, and robust.
|
By following these patterns, your preprocessing code becomes more maintainable, shareable, and robust.
|
||||||
|
|
||||||
For the full API reference, see the [RobotProcessor API documentation](/api/processor).
|
For the full API reference, see the [RobotProcessor API documentation](/api/processor).
|
||||||
|
|||||||
Reference in New Issue
Block a user