[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-07-03 13:22:41 +00:00
committed by Adil Zouitine
parent cd13f1ecfd
commit ac742c9f0d
+96 -96
View File
@@ -60,7 +60,7 @@ If you've worked with robot learning before, you've likely written preprocessing
# Traditional procedural approach - hard to maintain and share # Traditional procedural approach - hard to maintain and share
def preprocess_observation(obs, device='cuda'): def preprocess_observation(obs, device='cuda'):
processed_obs = {} processed_obs = {}
# Process images # Process images
if "pixels" in obs: if "pixels" in obs:
for cam_name, img in obs["pixels"].items(): for cam_name, img in obs["pixels"].items():
@@ -76,13 +76,13 @@ def preprocess_observation(obs, device='cuda'):
if img.shape[-2:] != (224, 224): if img.shape[-2:] != (224, 224):
img = F.pad(img, (0, 224 - img.shape[-1], 0, 224 - img.shape[-2])) img = F.pad(img, (0, 224 - img.shape[-1], 0, 224 - img.shape[-2]))
processed_obs[f"observation.images.{cam_name}"] = img processed_obs[f"observation.images.{cam_name}"] = img
# Process state # Process state
if "agent_pos" in obs: if "agent_pos" in obs:
state = torch.from_numpy(obs["agent_pos"]).float() state = torch.from_numpy(obs["agent_pos"]).float()
state = state.unsqueeze(0).to(device) state = state.unsqueeze(0).to(device)
processed_obs["observation.state"] = state processed_obs["observation.state"] = state
return processed_obs return processed_obs
``` ```
@@ -180,7 +180,7 @@ print("Original keys:", observation.keys())
print("Processed keys:", processed_obs.keys()) print("Processed keys:", processed_obs.keys())
print("Image shape:", processed_obs["observation.images.camera_front"].shape) # [1, 3, 480, 640] print("Image shape:", processed_obs["observation.images.camera_front"].shape) # [1, 3, 480, 640]
print("Image dtype:", processed_obs["observation.images.camera_front"].dtype) # torch.float32 print("Image dtype:", processed_obs["observation.images.camera_front"].dtype) # torch.float32
print("Image range:", processed_obs["observation.images.camera_front"].min().item(), print("Image range:", processed_obs["observation.images.camera_front"].min().item(),
"to", processed_obs["observation.images.camera_front"].max().item()) # 0.0 to 1.0 "to", processed_obs["observation.images.camera_front"].max().item()) # 0.0 to 1.0
``` ```
@@ -197,18 +197,18 @@ import torch.nn.functional as F
@dataclass @dataclass
class ImagePadder: class ImagePadder:
"""Pad images to a standard size for batch processing.""" """Pad images to a standard size for batch processing."""
target_height: int = 224 target_height: int = 224
target_width: int = 224 target_width: int = 224
pad_value: float = 0.0 pad_value: float = 0.0
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Main processing method - required for all steps.""" """Main processing method - required for all steps."""
obs = transition[TransitionIndex.OBSERVATION] obs = transition[TransitionIndex.OBSERVATION]
if obs is None: if obs is None:
return transition return transition
# Process all image observations # Process all image observations
for key in list(obs.keys()): for key in list(obs.keys()):
if key.startswith("observation.images."): if key.startswith("observation.images."):
@@ -217,22 +217,22 @@ class ImagePadder:
_, _, h, w = img.shape _, _, h, w = img.shape
pad_h = max(0, self.target_height - h) pad_h = max(0, self.target_height - h)
pad_w = max(0, self.target_width - w) pad_w = max(0, self.target_width - w)
if pad_h > 0 or pad_w > 0: if pad_h > 0 or pad_w > 0:
# Pad symmetrically # Pad symmetrically
pad_left = pad_w // 2 pad_left = pad_w // 2
pad_right = pad_w - pad_left pad_right = pad_w - pad_left
pad_top = pad_h // 2 pad_top = pad_h // 2
pad_bottom = pad_h - pad_top pad_bottom = pad_h - pad_top
img = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom), img = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom),
mode='constant', value=self.pad_value) mode='constant', value=self.pad_value)
obs[key] = img obs[key] = img
# Return modified transition # Return modified transition
return (obs, *transition[1:]) return (obs, *transition[1:])
def get_config(self) -> Dict[str, Any]: def get_config(self) -> Dict[str, Any]:
"""Return JSON-serializable configuration - required for save/load.""" """Return JSON-serializable configuration - required for save/load."""
return { return {
@@ -240,17 +240,17 @@ class ImagePadder:
"target_width": self.target_width, "target_width": self.target_width,
"pad_value": self.pad_value "pad_value": self.pad_value
} }
def state_dict(self) -> Dict[str, torch.Tensor]: def state_dict(self) -> Dict[str, torch.Tensor]:
"""Return tensor state - only include torch.Tensor objects!""" """Return tensor state - only include torch.Tensor objects!"""
# This step has no learnable parameters # This step has no learnable parameters
return {} return {}
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
"""Load tensor state - required if state_dict returns non-empty dict.""" """Load tensor state - required if state_dict returns non-empty dict."""
# Nothing to load for this step # Nothing to load for this step
pass pass
def reset(self) -> None: def reset(self) -> None:
"""Reset internal state at episode boundaries - required for stateful steps.""" """Reset internal state at episode boundaries - required for stateful steps."""
# This step is stateless, so nothing to reset # This step is stateless, so nothing to reset
@@ -265,15 +265,15 @@ These two methods serve different purposes and it's crucial to use them correctl
@dataclass @dataclass
class AdaptiveNormalizer: class AdaptiveNormalizer:
"""Example showing proper use of get_config and state_dict.""" """Example showing proper use of get_config and state_dict."""
learning_rate: float = 0.01 learning_rate: float = 0.01
epsilon: float = 1e-8 epsilon: float = 1e-8
def __post_init__(self): def __post_init__(self):
self.running_mean = None self.running_mean = None
self.running_var = None self.running_var = None
self.num_samples = 0 # Python int, not tensor self.num_samples = 0 # Python int, not tensor
def get_config(self) -> Dict[str, Any]: def get_config(self) -> Dict[str, Any]:
"""ONLY Python objects that can be JSON serialized!""" """ONLY Python objects that can be JSON serialized!"""
return { return {
@@ -282,7 +282,7 @@ class AdaptiveNormalizer:
"num_samples": self.num_samples, # int ✓ "num_samples": self.num_samples, # int ✓
# "running_mean": self.running_mean, # torch.Tensor ✗ WRONG! # "running_mean": self.running_mean, # torch.Tensor ✗ WRONG!
} }
def state_dict(self) -> Dict[str, torch.Tensor]: def state_dict(self) -> Dict[str, torch.Tensor]:
"""ONLY torch.Tensor objects!""" """ONLY torch.Tensor objects!"""
if self.running_mean is None: if self.running_mean is None:
@@ -294,7 +294,7 @@ class AdaptiveNormalizer:
# Instead, convert to tensor if needed: # Instead, convert to tensor if needed:
"num_samples_tensor": torch.tensor(self.num_samples) "num_samples_tensor": torch.tensor(self.num_samples)
} }
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
"""Load tensors and convert back to Python types if needed.""" """Load tensors and convert back to Python types if needed."""
self.running_mean = state.get("running_mean") self.running_mean = state.get("running_mean")
@@ -311,14 +311,14 @@ The `complementary_data` field is perfect for passing information between steps
@dataclass @dataclass
class ImageStatisticsCalculator: class ImageStatisticsCalculator:
"""Calculate image statistics and pass to next steps.""" """Calculate image statistics and pass to next steps."""
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION] obs = transition[TransitionIndex.OBSERVATION]
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {} comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {}
if obs is None: if obs is None:
return transition return transition
# Calculate statistics for all images # Calculate statistics for all images
image_stats = {} image_stats = {}
for key in obs: for key in obs:
@@ -331,10 +331,10 @@ class ImageStatisticsCalculator:
"max": img.max().item(), "max": img.max().item(),
} }
image_stats[key] = stats image_stats[key] = stats
# Store in complementary_data for next steps # Store in complementary_data for next steps
comp_data["image_statistics"] = image_stats comp_data["image_statistics"] = image_stats
# Return transition with updated complementary_data # Return transition with updated complementary_data
return ( return (
obs, obs,
@@ -346,30 +346,30 @@ class ImageStatisticsCalculator:
comp_data # Updated complementary_data comp_data # Updated complementary_data
) )
@dataclass @dataclass
class AdaptiveBrightnessAdjuster: class AdaptiveBrightnessAdjuster:
"""Adjust brightness based on statistics from previous step.""" """Adjust brightness based on statistics from previous step."""
target_brightness: float = 0.5 target_brightness: float = 0.5
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION] obs = transition[TransitionIndex.OBSERVATION]
comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {} comp_data = transition[TransitionIndex.COMPLEMENTARY_DATA] or {}
if obs is None or "image_statistics" not in comp_data: if obs is None or "image_statistics" not in comp_data:
return transition return transition
# Use statistics from previous step # Use statistics from previous step
image_stats = comp_data["image_statistics"] image_stats = comp_data["image_statistics"]
for key in obs: for key in obs:
if key.startswith("observation.images.") and key in image_stats: if key.startswith("observation.images.") and key in image_stats:
current_mean = image_stats[key]["mean"] current_mean = image_stats[key]["mean"]
brightness_adjust = self.target_brightness - current_mean brightness_adjust = self.target_brightness - current_mean
# Adjust brightness # Adjust brightness
obs[key] = torch.clamp(obs[key] + brightness_adjust, 0, 1) obs[key] = torch.clamp(obs[key] + brightness_adjust, 0, 1)
return (obs, *transition[1:]) return (obs, *transition[1:])
# Use them together # Use them together
@@ -394,28 +394,28 @@ import numpy as np
@dataclass @dataclass
class DeviceMover: class DeviceMover:
"""Move all tensors to specified device.""" """Move all tensors to specified device."""
device: str = "cuda" device: str = "cuda"
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION] obs = transition[TransitionIndex.OBSERVATION]
if obs is None: if obs is None:
return transition return transition
# Move all tensor observations to device # Move all tensor observations to device
for key, value in obs.items(): for key, value in obs.items():
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
obs[key] = value.to(self.device) obs[key] = value.to(self.device)
# Also handle action if present # Also handle action if present
action = transition[TransitionIndex.ACTION] action = transition[TransitionIndex.ACTION]
if action is not None and isinstance(action, torch.Tensor): if action is not None and isinstance(action, torch.Tensor):
action = action.to(self.device) action = action.to(self.device)
return (obs, action, *transition[2:]) return (obs, action, *transition[2:])
return (obs, *transition[1:]) return (obs, *transition[1:])
def get_config(self) -> Dict[str, Any]: def get_config(self) -> Dict[str, Any]:
return {"device": str(self.device)} return {"device": str(self.device)}
@@ -423,52 +423,52 @@ class DeviceMover:
@dataclass @dataclass
class RunningNormalizer: class RunningNormalizer:
"""Normalize using running statistics with proper device handling.""" """Normalize using running statistics with proper device handling."""
feature_dim: int feature_dim: int
momentum: float = 0.1 momentum: float = 0.1
def __post_init__(self): def __post_init__(self):
# Initialize as None - will be created on first call with correct device # Initialize as None - will be created on first call with correct device
self.running_mean = None self.running_mean = None
self.running_var = None self.running_var = None
self.initialized = False self.initialized = False
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION] obs = transition[TransitionIndex.OBSERVATION]
if obs is None or "observation.state" not in obs: if obs is None or "observation.state" not in obs:
return transition return transition
state = obs["observation.state"] state = obs["observation.state"]
# Initialize on first call with correct device # Initialize on first call with correct device
if not self.initialized: if not self.initialized:
device = state.device device = state.device
self.running_mean = torch.zeros(self.feature_dim, device=device) self.running_mean = torch.zeros(self.feature_dim, device=device)
self.running_var = torch.ones(self.feature_dim, device=device) self.running_var = torch.ones(self.feature_dim, device=device)
self.initialized = True self.initialized = True
# Update statistics # Update statistics
with torch.no_grad(): with torch.no_grad():
batch_mean = state.mean(dim=0) batch_mean = state.mean(dim=0)
batch_var = state.var(dim=0, unbiased=False) batch_var = state.var(dim=0, unbiased=False)
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
# Normalize # Normalize
state_normalized = (state - self.running_mean) / (self.running_var + 1e-8).sqrt() state_normalized = (state - self.running_mean) / (self.running_var + 1e-8).sqrt()
obs["observation.state"] = state_normalized obs["observation.state"] = state_normalized
return (obs, *transition[1:]) return (obs, *transition[1:])
def get_config(self) -> Dict[str, Any]: def get_config(self) -> Dict[str, Any]:
return { return {
"feature_dim": self.feature_dim, "feature_dim": self.feature_dim,
"momentum": self.momentum, "momentum": self.momentum,
"initialized": self.initialized "initialized": self.initialized
} }
def state_dict(self) -> Dict[str, torch.Tensor]: def state_dict(self) -> Dict[str, torch.Tensor]:
if not self.initialized: if not self.initialized:
return {} return {}
@@ -476,13 +476,13 @@ class RunningNormalizer:
"running_mean": self.running_mean, "running_mean": self.running_mean,
"running_var": self.running_var "running_var": self.running_var
} }
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None:
if state: if state:
self.running_mean = state["running_mean"] self.running_mean = state["running_mean"]
self.running_var = state["running_var"] self.running_var = state["running_var"]
self.initialized = True self.initialized = True
def reset(self) -> None: def reset(self) -> None:
# Don't reset statistics - they persist across episodes # Don't reset statistics - they persist across episodes
pass pass
@@ -524,29 +524,29 @@ When your environment and policy speak different "languages":
@dataclass @dataclass
class KeyRemapper: class KeyRemapper:
"""Rename observation keys to match policy expectations.""" """Rename observation keys to match policy expectations."""
key_mapping: Dict[str, str] = field(default_factory=lambda: { key_mapping: Dict[str, str] = field(default_factory=lambda: {
"rgb_camera_front": "observation.images.wrist", "rgb_camera_front": "observation.images.wrist",
"joint_positions": "observation.state", "joint_positions": "observation.state",
"gripper_state": "observation.gripper" "gripper_state": "observation.gripper"
}) })
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION] obs = transition[TransitionIndex.OBSERVATION]
if obs is None: if obs is None:
return transition return transition
# Create new observation with renamed keys # Create new observation with renamed keys
renamed_obs = {} renamed_obs = {}
for old_key, new_key in self.key_mapping.items(): for old_key, new_key in self.key_mapping.items():
if old_key in obs: if old_key in obs:
renamed_obs[new_key] = obs[old_key] renamed_obs[new_key] = obs[old_key]
# Keep any unmapped keys as-is # Keep any unmapped keys as-is
for key, value in obs.items(): for key, value in obs.items():
if key not in self.key_mapping: if key not in self.key_mapping:
renamed_obs[key] = value renamed_obs[key] = value
return (renamed_obs, *transition[1:]) return (renamed_obs, *transition[1:])
``` ```
@@ -559,15 +559,15 @@ When you need to crop and resize images to focus on the manipulation workspace:
@dataclass @dataclass
class WorkspaceCropper: class WorkspaceCropper:
"""Crop and resize images to focus on robot workspace.""" """Crop and resize images to focus on robot workspace."""
crop_bbox: Tuple[int, int, int, int] = (400, 200, 1200, 800) # (x1, y1, x2, y2) crop_bbox: Tuple[int, int, int, int] = (400, 200, 1200, 800) # (x1, y1, x2, y2)
output_size: Tuple[int, int] = (224, 224) output_size: Tuple[int, int] = (224, 224)
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION] obs = transition[TransitionIndex.OBSERVATION]
if obs is None: if obs is None:
return transition return transition
for key in list(obs.keys()): for key in list(obs.keys()):
if key.startswith("observation.images."): if key.startswith("observation.images."):
img = obs[key] img = obs[key]
@@ -576,13 +576,13 @@ class WorkspaceCropper:
img_cropped = img[:, :, y1:y2, x1:x2] img_cropped = img[:, :, y1:y2, x1:x2]
# Resize to expected dimensions # Resize to expected dimensions
img_resized = F.interpolate( img_resized = F.interpolate(
img_cropped, img_cropped,
size=self.output_size, size=self.output_size,
mode='bilinear', mode='bilinear',
align_corners=False align_corners=False
) )
obs[key] = img_resized obs[key] = img_resized
return (obs, *transition[1:]) return (obs, *transition[1:])
``` ```
@@ -596,7 +596,7 @@ Now you can compose these steps into robot-specific pipelines:
robot_a_processor = RobotProcessor([ robot_a_processor = RobotProcessor([
# Observation preprocessing # Observation preprocessing
KeyRemapper({ KeyRemapper({
"wrist_rgb": "observation.images.wrist", "wrist_rgb": "observation.images.wrist",
"arm_joints": "observation.state" "arm_joints": "observation.state"
}), }),
ImageProcessor(), ImageProcessor(),
@@ -604,7 +604,7 @@ robot_a_processor = RobotProcessor([
StateProcessor(), StateProcessor(),
VelocityCalculator(state_key="observation.state"), VelocityCalculator(state_key="observation.state"),
DeviceMover("cuda"), DeviceMover("cuda"),
# Action postprocessing # Action postprocessing
ActionSmoother(alpha=0.2), ActionSmoother(alpha=0.2),
], name="RobotA_ACT_Processor") ], name="RobotA_ACT_Processor")
@@ -618,10 +618,10 @@ robot_b_processor = RobotProcessor([
}), }),
ImageProcessor(), ImageProcessor(),
WorkspaceCropper(crop_bbox=(100, 50, 1100, 950), output_size=(224, 224)), WorkspaceCropper(crop_bbox=(100, 50, 1100, 950), output_size=(224, 224)),
StateProcessor(), StateProcessor(),
VelocityCalculator(state_key="observation.state"), VelocityCalculator(state_key="observation.state"),
DeviceMover("cuda"), DeviceMover("cuda"),
ActionSmoother(alpha=0.3), # Different smoothing ActionSmoother(alpha=0.3), # Different smoothing
], name="RobotB_ACT_Processor") ], name="RobotB_ACT_Processor")
@@ -645,11 +645,11 @@ The beauty of this approach is that:
```python ```python
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION] obs = transition[TransitionIndex.OBSERVATION]
# Always check if observation exists # Always check if observation exists
if obs is None: if obs is None:
return transition return transition
# Also check for specific keys # Also check for specific keys
if "observation.state" not in obs: if "observation.state" not in obs:
return transition return transition
@@ -668,11 +668,11 @@ return (modified_obs, None, 0.0, False, False, {}, {})
```python ```python
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition[TransitionIndex.OBSERVATION] obs = transition[TransitionIndex.OBSERVATION]
if self.store_previous: if self.store_previous:
# Good - clone to avoid reference issues # Good - clone to avoid reference issues
self.previous_state = obs["observation.state"].clone() self.previous_state = obs["observation.state"].clone()
# Bad - stores reference that might be modified # Bad - stores reference that might be modified
# self.previous_state = obs["observation.state"] # self.previous_state = obs["observation.state"]
``` ```
@@ -682,7 +682,7 @@ def __call__(self, transition: EnvTransition) -> EnvTransition:
def state_dict(self) -> Dict[str, torch.Tensor]: def state_dict(self) -> Dict[str, torch.Tensor]:
if self.buffer is None: if self.buffer is None:
return {} return {}
# Good - clone to avoid memory sharing issues # Good - clone to avoid memory sharing issues
return {"buffer": self.buffer.clone()} return {"buffer": self.buffer.clone()}
``` ```
@@ -713,16 +713,16 @@ class ActionClipper:
"""Clip actions to safe ranges.""" """Clip actions to safe ranges."""
min_value: float = -1.0 min_value: float = -1.0
max_value: float = 1.0 max_value: float = 1.0
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition[TransitionIndex.ACTION] action = transition[TransitionIndex.ACTION]
if action is not None: if action is not None:
action = torch.clamp(action, self.min_value, self.max_value) action = torch.clamp(action, self.min_value, self.max_value)
return (transition[TransitionIndex.OBSERVATION], action, *transition[2:]) return (transition[TransitionIndex.OBSERVATION], action, *transition[2:])
return transition return transition
def get_config(self) -> Dict[str, Any]: def get_config(self) -> Dict[str, Any]:
return {"min_value": self.min_value, "max_value": self.max_value} return {"min_value": self.min_value, "max_value": self.max_value}
@@ -736,7 +736,7 @@ policy = ACTPolicy.from_pretrained("lerobot/act_aloha_sim_transfer_cube_human")
# Move everything to GPU # Move everything to GPU
preprocessor = preprocessor.to("cuda") preprocessor = preprocessor.to("cuda")
postprocessor = postprocessor.to("cuda") postprocessor = postprocessor.to("cuda")
policy = policy.to("cuda") policy = policy.to("cuda")
# Control loop # Control loop
@@ -745,42 +745,42 @@ obs, info = env.reset()
for episode in range(10): for episode in range(10):
print(f"Episode {episode + 1}") print(f"Episode {episode + 1}")
for step in range(1000): for step in range(1000):
# Create transition with raw observation # Create transition with raw observation
transition = (obs, None, 0.0, False, False, info, {"step": step}) transition = (obs, None, 0.0, False, False, info, {"step": step})
# Preprocess # Preprocess
processed_transition = preprocessor(transition) processed_transition = preprocessor(transition)
processed_obs = processed_transition[TransitionIndex.OBSERVATION] processed_obs = processed_transition[TransitionIndex.OBSERVATION]
# Get action from policy # Get action from policy
with torch.no_grad(): with torch.no_grad():
action = policy.select_action(processed_obs) action = policy.select_action(processed_obs)
# Postprocess action # Postprocess action
action_transition = ( action_transition = (
processed_obs, processed_obs,
action, action,
0.0, 0.0,
False, False,
False, False,
info, info,
{"raw_action": action.clone()} # Store raw action in complementary_data {"raw_action": action.clone()} # Store raw action in complementary_data
) )
processed_action_transition = postprocessor(action_transition) processed_action_transition = postprocessor(action_transition)
final_action = processed_action_transition[TransitionIndex.ACTION] final_action = processed_action_transition[TransitionIndex.ACTION]
# Execute action # Execute action
obs, reward, terminated, truncated, info = env.step(final_action.cpu().numpy()) obs, reward, terminated, truncated, info = env.step(final_action.cpu().numpy())
if terminated or truncated: if terminated or truncated:
# Reset at episode boundary # Reset at episode boundary
preprocessor.reset() preprocessor.reset()
postprocessor.reset() postprocessor.reset()
obs, info = env.reset() obs, info = env.reset()
break break
# Save preprocessor with learned statistics # Save preprocessor with learned statistics
preprocessor.save_pretrained(f"./checkpoints/preprocessor_ep{episode}") preprocessor.save_pretrained(f"./checkpoints/preprocessor_ep{episode}")
@@ -844,4 +844,4 @@ RobotProcessor provides a powerful, modular approach to data preprocessing in ro
By following these patterns, your preprocessing code becomes more maintainable, shareable, and robust. By following these patterns, your preprocessing code becomes more maintainable, shareable, and robust.
For the full API reference, see the [RobotProcessor API documentation](/api/processor). For the full API reference, see the [RobotProcessor API documentation](/api/processor).