[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-07-02 15:31:15 +00:00
committed by Adil Zouitine
parent f6c7287ae7
commit 769f531603
9 changed files with 485 additions and 475 deletions
+97 -88
View File
@@ -16,87 +16,86 @@
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from dataclasses import dataclass
import numpy as np
import pytest
import torch
from lerobot.processor.pipeline import RobotPipeline, EnvTransition, PipelineStep
from lerobot.processor.pipeline import EnvTransition, RobotPipeline
@dataclass
class MockStep:
"""Mock pipeline step for testing - demonstrates best practices.
This example shows the proper separation:
- JSON-serializable attributes (name, counter) go in get_config()
- Only torch tensors go in state_dict()
Note: The counter is part of the configuration, so it will be restored
when the step is recreated from config during loading.
"""
name: str = "mock_step"
counter: int = 0
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Add a counter to the complementary_data."""
obs, action, reward, done, truncated, info, comp_data = transition
if comp_data is None:
comp_data = {}
else:
comp_data = dict(comp_data) # Make a copy
comp_data[f"{self.name}_counter"] = self.counter
self.counter += 1
return (obs, action, reward, done, truncated, info, comp_data)
def get_config(self) -> dict[str, Any]:
# Return all JSON-serializable attributes that should be persisted
# These will be passed to __init__ when loading
return {"name": self.name, "counter": self.counter}
def state_dict(self) -> dict[str, torch.Tensor]:
# Only return torch tensors (empty in this case since we have no tensor state)
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
# No tensor state to load
pass
def reset(self) -> None:
self.counter = 0
@dataclass
@dataclass
class MockStepWithoutOptionalMethods:
"""Mock step that only implements the required __call__ method."""
multiplier: float = 2.0
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Multiply reward by multiplier."""
obs, action, reward, done, truncated, info, comp_data = transition
if reward is not None:
reward = reward * self.multiplier
return (obs, action, reward, done, truncated, info, comp_data)
@dataclass
class MockStepWithTensorState:
"""Mock step demonstrating mixed JSON attributes and tensor state."""
name: str = "tensor_step"
learning_rate: float = 0.01
window_size: int = 10
def __init__(self, name: str = "tensor_step", learning_rate: float = 0.01, window_size: int = 10):
self.name = name
self.learning_rate = learning_rate
@@ -104,19 +103,19 @@ class MockStepWithTensorState:
# Tensor state
self.running_mean = torch.zeros(window_size)
self.running_count = torch.tensor(0)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Update running statistics."""
obs, action, reward, done, truncated, info, comp_data = transition
if reward is not None:
# Update running mean
idx = self.running_count % self.window_size
self.running_mean[idx] = reward
self.running_count += 1
return transition
def get_config(self) -> dict[str, Any]:
# Only JSON-serializable attributes
return {
@@ -124,18 +123,18 @@ class MockStepWithTensorState:
"learning_rate": self.learning_rate,
"window_size": self.window_size,
}
def state_dict(self) -> dict[str, torch.Tensor]:
# Only tensor state
return {
"running_mean": self.running_mean,
"running_count": self.running_count,
}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
self.running_mean = state["running_mean"]
self.running_count = state["running_count"]
def reset(self) -> None:
self.running_mean.zero_()
self.running_count.zero_()
@@ -144,265 +143,275 @@ class MockStepWithTensorState:
def test_empty_pipeline():
"""Test pipeline with no steps."""
pipeline = RobotPipeline()
transition = (None, None, 0.0, False, False, {}, {})
result = pipeline(transition)
assert result == transition
assert len(pipeline) == 0
def test_single_step_pipeline():
"""Test pipeline with a single step."""
step = MockStep("test_step")
pipeline = RobotPipeline([step])
transition = (None, None, 0.0, False, False, {}, {})
result = pipeline(transition)
assert len(pipeline) == 1
assert result[6]["test_step_counter"] == 0 # complementary_data
# Call again to test counter increment
result = pipeline(transition)
assert result[6]["test_step_counter"] == 1
def test_multiple_steps_pipeline():
"""Test pipeline with multiple steps."""
step1 = MockStep("step1")
step2 = MockStep("step2")
step2 = MockStep("step2")
pipeline = RobotPipeline([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
result = pipeline(transition)
assert len(pipeline) == 2
assert result[6]["step1_counter"] == 0
assert result[6]["step2_counter"] == 0
def test_invalid_transition_format():
"""Test pipeline with invalid transition format."""
pipeline = RobotPipeline([MockStep()])
# Test with wrong number of elements
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
pipeline((None, None, 0.0)) # Only 3 elements
# Test with wrong type
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
pipeline("not a tuple")
def test_step_through():
"""Test step_through method."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotPipeline([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
results = list(pipeline.step_through(transition))
assert len(results) == 3 # Original + 2 steps
assert results[0] == transition # Original
assert "step1_counter" in results[1][6] # After step1
assert "step2_counter" in results[2][6] # After step2
def test_indexing():
"""Test pipeline indexing."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotPipeline([step1, step2])
# Test integer indexing
assert pipeline[0] is step1
assert pipeline[1] is step2
# Test slice indexing
sub_pipeline = pipeline[0:1]
assert isinstance(sub_pipeline, RobotPipeline)
assert len(sub_pipeline) == 1
assert sub_pipeline[0] is step1
def test_hooks():
"""Test before/after step hooks."""
step = MockStep("test_step")
pipeline = RobotPipeline([step])
before_calls = []
after_calls = []
def before_hook(idx: int, transition: EnvTransition):
before_calls.append(idx)
return transition
def after_hook(idx: int, transition: EnvTransition):
after_calls.append(idx)
return transition
pipeline.register_before_step_hook(before_hook)
pipeline.register_after_step_hook(after_hook)
transition = (None, None, 0.0, False, False, {}, {})
pipeline(transition)
assert before_calls == [0]
assert after_calls == [0]
def test_hook_modification():
"""Test that hooks can modify transitions."""
step = MockStep("test_step")
pipeline = RobotPipeline([step])
def modify_reward_hook(idx: int, transition: EnvTransition):
obs, action, reward, done, truncated, info, comp_data = transition
return (obs, action, 42.0, done, truncated, info, comp_data)
pipeline.register_before_step_hook(modify_reward_hook)
transition = (None, None, 0.0, False, False, {}, {})
result = pipeline(transition)
assert result[2] == 42.0 # reward modified by hook
def test_reset():
"""Test pipeline reset functionality."""
step = MockStep("test_step")
pipeline = RobotPipeline([step])
reset_called = []
def reset_hook():
reset_called.append(True)
pipeline.register_reset_hook(reset_hook)
# Make some calls to increment counter
transition = (None, None, 0.0, False, False, {}, {})
pipeline(transition)
pipeline(transition)
assert step.counter == 2
# Reset should reset step and call hook
pipeline.reset()
assert step.counter == 0
assert len(reset_called) == 1
def test_profile_steps():
"""Test step profiling functionality."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotPipeline([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
profile_results = pipeline.profile_steps(transition, num_runs=10)
assert len(profile_results) == 2
assert "step_0_MockStep" in profile_results
assert "step_1_MockStep" in profile_results
assert all(isinstance(time, float) and time >= 0 for time in profile_results.values())
def test_save_and_load_pretrained():
"""Test saving and loading pipeline.
This test demonstrates that JSON-serializable attributes (like counter)
are saved in the config and restored when the step is recreated.
"""
step1 = MockStep("step1")
step2 = MockStep("step2")
# Increment counters to have some state
step1.counter = 5
step2.counter = 10
pipeline = RobotPipeline([step1, step2], name="TestPipeline", seed=42)
with tempfile.TemporaryDirectory() as tmp_dir:
# Save pipeline
pipeline.save_pretrained(tmp_dir)
# Check files were created
config_path = Path(tmp_dir) / "pipeline.json"
assert config_path.exists()
# Check config content
with open(config_path) as f:
config = json.load(f)
assert config["name"] == "TestPipeline"
assert config["seed"] == 42
assert len(config["steps"]) == 2
# Verify counters are saved in config, not in separate state files
assert config["steps"][0]["config"]["counter"] == 5
assert config["steps"][1]["config"]["counter"] == 10
# Load pipeline
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
assert loaded_pipeline.name == "TestPipeline"
assert loaded_pipeline.seed == 42
assert len(loaded_pipeline) == 2
# Check that counter was restored from config
assert loaded_pipeline.steps[0].counter == 5
assert loaded_pipeline.steps[1].counter == 10
def test_step_without_optional_methods():
"""Test pipeline with steps that don't implement optional methods."""
step = MockStepWithoutOptionalMethods(multiplier=3.0)
pipeline = RobotPipeline([step])
transition = (None, None, 2.0, False, False, {}, {})
result = pipeline(transition)
assert result[2] == 6.0 # 2.0 * 3.0
# Reset should work even if step doesn't implement reset
pipeline.reset()
# Save/load should work even without optional methods
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
assert len(loaded_pipeline) == 1
def test_mixed_json_and_tensor_state():
"""Test step with both JSON attributes and tensor state."""
step = MockStepWithTensorState(name="stats", learning_rate=0.05, window_size=5)
pipeline = RobotPipeline([step])
# Process some transitions with rewards
for i in range(10):
transition = (None, None, float(i), False, False, {}, {})
pipeline(transition)
# Check state
assert step.running_count.item() == 10
assert step.learning_rate == 0.05
# Save and load
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
# Check that both config and state files were created
config_path = Path(tmp_dir) / "pipeline.json"
config_path = Path(tmp_dir) / "pipeline.json"
state_path = Path(tmp_dir) / "step_0.safetensors"
assert config_path.exists()
assert state_path.exists()
# Load and verify
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
loaded_step = loaded_pipeline.steps[0]
# Check JSON attributes were restored
assert loaded_step.name == "stats"
assert loaded_step.learning_rate == 0.05
assert loaded_step.window_size == 5
# Check tensor state was restored
assert loaded_step.running_count.item() == 10
assert torch.allclose(loaded_step.running_mean, step.running_mean)