mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 09:39:47 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Adil Zouitine
parent
f6c7287ae7
commit
769f531603
@@ -16,87 +16,86 @@
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.processor.pipeline import RobotPipeline, EnvTransition, PipelineStep
|
||||
from lerobot.processor.pipeline import EnvTransition, RobotPipeline
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockStep:
|
||||
"""Mock pipeline step for testing - demonstrates best practices.
|
||||
|
||||
|
||||
This example shows the proper separation:
|
||||
- JSON-serializable attributes (name, counter) go in get_config()
|
||||
- Only torch tensors go in state_dict()
|
||||
|
||||
|
||||
Note: The counter is part of the configuration, so it will be restored
|
||||
when the step is recreated from config during loading.
|
||||
"""
|
||||
|
||||
|
||||
name: str = "mock_step"
|
||||
counter: int = 0
|
||||
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Add a counter to the complementary_data."""
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
|
||||
|
||||
if comp_data is None:
|
||||
comp_data = {}
|
||||
else:
|
||||
comp_data = dict(comp_data) # Make a copy
|
||||
|
||||
|
||||
comp_data[f"{self.name}_counter"] = self.counter
|
||||
self.counter += 1
|
||||
|
||||
|
||||
return (obs, action, reward, done, truncated, info, comp_data)
|
||||
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
# Return all JSON-serializable attributes that should be persisted
|
||||
# These will be passed to __init__ when loading
|
||||
return {"name": self.name, "counter": self.counter}
|
||||
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
# Only return torch tensors (empty in this case since we have no tensor state)
|
||||
return {}
|
||||
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
# No tensor state to load
|
||||
pass
|
||||
|
||||
|
||||
def reset(self) -> None:
|
||||
self.counter = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass
|
||||
class MockStepWithoutOptionalMethods:
|
||||
"""Mock step that only implements the required __call__ method."""
|
||||
|
||||
|
||||
multiplier: float = 2.0
|
||||
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Multiply reward by multiplier."""
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
|
||||
|
||||
if reward is not None:
|
||||
reward = reward * self.multiplier
|
||||
|
||||
|
||||
return (obs, action, reward, done, truncated, info, comp_data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockStepWithTensorState:
|
||||
"""Mock step demonstrating mixed JSON attributes and tensor state."""
|
||||
|
||||
|
||||
name: str = "tensor_step"
|
||||
learning_rate: float = 0.01
|
||||
window_size: int = 10
|
||||
|
||||
|
||||
def __init__(self, name: str = "tensor_step", learning_rate: float = 0.01, window_size: int = 10):
|
||||
self.name = name
|
||||
self.learning_rate = learning_rate
|
||||
@@ -104,19 +103,19 @@ class MockStepWithTensorState:
|
||||
# Tensor state
|
||||
self.running_mean = torch.zeros(window_size)
|
||||
self.running_count = torch.tensor(0)
|
||||
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Update running statistics."""
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
|
||||
|
||||
if reward is not None:
|
||||
# Update running mean
|
||||
idx = self.running_count % self.window_size
|
||||
self.running_mean[idx] = reward
|
||||
self.running_count += 1
|
||||
|
||||
|
||||
return transition
|
||||
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
# Only JSON-serializable attributes
|
||||
return {
|
||||
@@ -124,18 +123,18 @@ class MockStepWithTensorState:
|
||||
"learning_rate": self.learning_rate,
|
||||
"window_size": self.window_size,
|
||||
}
|
||||
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
# Only tensor state
|
||||
return {
|
||||
"running_mean": self.running_mean,
|
||||
"running_count": self.running_count,
|
||||
}
|
||||
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
self.running_mean = state["running_mean"]
|
||||
self.running_count = state["running_count"]
|
||||
|
||||
|
||||
def reset(self) -> None:
|
||||
self.running_mean.zero_()
|
||||
self.running_count.zero_()
|
||||
@@ -144,265 +143,275 @@ class MockStepWithTensorState:
|
||||
def test_empty_pipeline():
|
||||
"""Test pipeline with no steps."""
|
||||
pipeline = RobotPipeline()
|
||||
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
result = pipeline(transition)
|
||||
|
||||
|
||||
assert result == transition
|
||||
assert len(pipeline) == 0
|
||||
|
||||
|
||||
def test_single_step_pipeline():
|
||||
"""Test pipeline with a single step."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotPipeline([step])
|
||||
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
result = pipeline(transition)
|
||||
|
||||
|
||||
assert len(pipeline) == 1
|
||||
assert result[6]["test_step_counter"] == 0 # complementary_data
|
||||
|
||||
|
||||
# Call again to test counter increment
|
||||
result = pipeline(transition)
|
||||
assert result[6]["test_step_counter"] == 1
|
||||
|
||||
|
||||
def test_multiple_steps_pipeline():
|
||||
"""Test pipeline with multiple steps."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStep("step2")
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotPipeline([step1, step2])
|
||||
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
result = pipeline(transition)
|
||||
|
||||
|
||||
assert len(pipeline) == 2
|
||||
assert result[6]["step1_counter"] == 0
|
||||
assert result[6]["step2_counter"] == 0
|
||||
|
||||
|
||||
def test_invalid_transition_format():
|
||||
"""Test pipeline with invalid transition format."""
|
||||
pipeline = RobotPipeline([MockStep()])
|
||||
|
||||
|
||||
# Test with wrong number of elements
|
||||
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
|
||||
pipeline((None, None, 0.0)) # Only 3 elements
|
||||
|
||||
|
||||
# Test with wrong type
|
||||
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
|
||||
pipeline("not a tuple")
|
||||
|
||||
|
||||
def test_step_through():
|
||||
"""Test step_through method."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotPipeline([step1, step2])
|
||||
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
|
||||
|
||||
results = list(pipeline.step_through(transition))
|
||||
|
||||
|
||||
assert len(results) == 3 # Original + 2 steps
|
||||
assert results[0] == transition # Original
|
||||
assert "step1_counter" in results[1][6] # After step1
|
||||
assert "step2_counter" in results[2][6] # After step2
|
||||
|
||||
|
||||
def test_indexing():
|
||||
"""Test pipeline indexing."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotPipeline([step1, step2])
|
||||
|
||||
|
||||
# Test integer indexing
|
||||
assert pipeline[0] is step1
|
||||
assert pipeline[1] is step2
|
||||
|
||||
|
||||
# Test slice indexing
|
||||
sub_pipeline = pipeline[0:1]
|
||||
assert isinstance(sub_pipeline, RobotPipeline)
|
||||
assert len(sub_pipeline) == 1
|
||||
assert sub_pipeline[0] is step1
|
||||
|
||||
|
||||
def test_hooks():
|
||||
"""Test before/after step hooks."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotPipeline([step])
|
||||
|
||||
|
||||
before_calls = []
|
||||
after_calls = []
|
||||
|
||||
|
||||
def before_hook(idx: int, transition: EnvTransition):
|
||||
before_calls.append(idx)
|
||||
return transition
|
||||
|
||||
|
||||
def after_hook(idx: int, transition: EnvTransition):
|
||||
after_calls.append(idx)
|
||||
return transition
|
||||
|
||||
|
||||
pipeline.register_before_step_hook(before_hook)
|
||||
pipeline.register_after_step_hook(after_hook)
|
||||
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
pipeline(transition)
|
||||
|
||||
|
||||
assert before_calls == [0]
|
||||
assert after_calls == [0]
|
||||
|
||||
|
||||
def test_hook_modification():
|
||||
"""Test that hooks can modify transitions."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotPipeline([step])
|
||||
|
||||
|
||||
def modify_reward_hook(idx: int, transition: EnvTransition):
|
||||
obs, action, reward, done, truncated, info, comp_data = transition
|
||||
return (obs, action, 42.0, done, truncated, info, comp_data)
|
||||
|
||||
|
||||
pipeline.register_before_step_hook(modify_reward_hook)
|
||||
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
result = pipeline(transition)
|
||||
|
||||
|
||||
assert result[2] == 42.0 # reward modified by hook
|
||||
|
||||
|
||||
def test_reset():
|
||||
"""Test pipeline reset functionality."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotPipeline([step])
|
||||
|
||||
|
||||
reset_called = []
|
||||
|
||||
|
||||
def reset_hook():
|
||||
reset_called.append(True)
|
||||
|
||||
|
||||
pipeline.register_reset_hook(reset_hook)
|
||||
|
||||
|
||||
# Make some calls to increment counter
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
pipeline(transition)
|
||||
pipeline(transition)
|
||||
|
||||
|
||||
assert step.counter == 2
|
||||
|
||||
|
||||
# Reset should reset step and call hook
|
||||
pipeline.reset()
|
||||
|
||||
|
||||
assert step.counter == 0
|
||||
assert len(reset_called) == 1
|
||||
|
||||
|
||||
def test_profile_steps():
|
||||
"""Test step profiling functionality."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotPipeline([step1, step2])
|
||||
|
||||
|
||||
transition = (None, None, 0.0, False, False, {}, {})
|
||||
|
||||
|
||||
profile_results = pipeline.profile_steps(transition, num_runs=10)
|
||||
|
||||
|
||||
assert len(profile_results) == 2
|
||||
assert "step_0_MockStep" in profile_results
|
||||
assert "step_1_MockStep" in profile_results
|
||||
assert all(isinstance(time, float) and time >= 0 for time in profile_results.values())
|
||||
|
||||
|
||||
def test_save_and_load_pretrained():
|
||||
"""Test saving and loading pipeline.
|
||||
|
||||
|
||||
This test demonstrates that JSON-serializable attributes (like counter)
|
||||
are saved in the config and restored when the step is recreated.
|
||||
"""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStep("step2")
|
||||
|
||||
|
||||
# Increment counters to have some state
|
||||
step1.counter = 5
|
||||
step2.counter = 10
|
||||
|
||||
|
||||
pipeline = RobotPipeline([step1, step2], name="TestPipeline", seed=42)
|
||||
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save pipeline
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
|
||||
# Check files were created
|
||||
config_path = Path(tmp_dir) / "pipeline.json"
|
||||
assert config_path.exists()
|
||||
|
||||
|
||||
# Check config content
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
|
||||
assert config["name"] == "TestPipeline"
|
||||
assert config["seed"] == 42
|
||||
assert len(config["steps"]) == 2
|
||||
|
||||
|
||||
# Verify counters are saved in config, not in separate state files
|
||||
assert config["steps"][0]["config"]["counter"] == 5
|
||||
assert config["steps"][1]["config"]["counter"] == 10
|
||||
|
||||
|
||||
# Load pipeline
|
||||
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
|
||||
|
||||
|
||||
assert loaded_pipeline.name == "TestPipeline"
|
||||
assert loaded_pipeline.seed == 42
|
||||
assert len(loaded_pipeline) == 2
|
||||
|
||||
|
||||
# Check that counter was restored from config
|
||||
assert loaded_pipeline.steps[0].counter == 5
|
||||
assert loaded_pipeline.steps[1].counter == 10
|
||||
|
||||
|
||||
def test_step_without_optional_methods():
|
||||
"""Test pipeline with steps that don't implement optional methods."""
|
||||
step = MockStepWithoutOptionalMethods(multiplier=3.0)
|
||||
pipeline = RobotPipeline([step])
|
||||
|
||||
|
||||
transition = (None, None, 2.0, False, False, {}, {})
|
||||
result = pipeline(transition)
|
||||
|
||||
|
||||
assert result[2] == 6.0 # 2.0 * 3.0
|
||||
|
||||
|
||||
# Reset should work even if step doesn't implement reset
|
||||
pipeline.reset()
|
||||
|
||||
|
||||
# Save/load should work even without optional methods
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
|
||||
assert len(loaded_pipeline) == 1
|
||||
|
||||
|
||||
def test_mixed_json_and_tensor_state():
|
||||
"""Test step with both JSON attributes and tensor state."""
|
||||
step = MockStepWithTensorState(name="stats", learning_rate=0.05, window_size=5)
|
||||
pipeline = RobotPipeline([step])
|
||||
|
||||
|
||||
# Process some transitions with rewards
|
||||
for i in range(10):
|
||||
transition = (None, None, float(i), False, False, {}, {})
|
||||
pipeline(transition)
|
||||
|
||||
|
||||
# Check state
|
||||
assert step.running_count.item() == 10
|
||||
assert step.learning_rate == 0.05
|
||||
|
||||
|
||||
# Save and load
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
|
||||
# Check that both config and state files were created
|
||||
config_path = Path(tmp_dir) / "pipeline.json"
|
||||
config_path = Path(tmp_dir) / "pipeline.json"
|
||||
state_path = Path(tmp_dir) / "step_0.safetensors"
|
||||
assert config_path.exists()
|
||||
assert state_path.exists()
|
||||
|
||||
|
||||
# Load and verify
|
||||
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
|
||||
# Check JSON attributes were restored
|
||||
assert loaded_step.name == "stats"
|
||||
assert loaded_step.learning_rate == 0.05
|
||||
assert loaded_step.window_size == 5
|
||||
|
||||
|
||||
# Check tensor state was restored
|
||||
assert loaded_step.running_count.item() == 10
|
||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user