mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fa3eb9fce3 | |||
| 500c91ba92 | |||
| 49755a3d9e | |||
| 09808183ca |
@@ -647,5 +647,6 @@ The `--strategy.type` flag selects the execution mode:
|
|||||||
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||||
|
- `episodic`: Episode-oriented policy recording with reset phases between episodes
|
||||||
|
|
||||||
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||||
|
|||||||
@@ -157,6 +157,44 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
|
|||||||
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||||
| `--teleop.type` | **Required.** Teleoperator type |
|
| `--teleop.type` | **Required.** Teleoperator type |
|
||||||
|
|
||||||
|
### Episodic (`--strategy.type=episodic`)
|
||||||
|
|
||||||
|
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=episodic \
|
||||||
|
--policy.path=${HF_USER}/my_policy \
|
||||||
|
--robot.type=so100_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--teleop.type=so100_leader \
|
||||||
|
--teleop.port=/dev/ttyACM1 \
|
||||||
|
--dataset.repo_id=${HF_USER}/my_eval_data \
|
||||||
|
--dataset.num_episodes=20 \
|
||||||
|
--dataset.episode_time_s=30 \
|
||||||
|
--dataset.reset_time_s=10 \
|
||||||
|
--dataset.single_task="Pick up the red cube"
|
||||||
|
```
|
||||||
|
|
||||||
|
Teleop is optional — if omitted, the robot holds its position during the reset phase.
|
||||||
|
|
||||||
|
**Keyboard controls:**
|
||||||
|
|
||||||
|
| Key | Action |
|
||||||
|
| ----------- | -------------------------------- |
|
||||||
|
| `→` (right) | End the current episode early |
|
||||||
|
| `←` (left) | Discard episode and re-record it |
|
||||||
|
| `ESC` | Stop the recording session |
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
| ----------------------------------------------- | -------------------------------------------------------------------------- |
|
||||||
|
| `--dataset.num_episodes` | Number of episodes to record |
|
||||||
|
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
|
||||||
|
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
|
||||||
|
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
|
||||||
|
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
|
||||||
|
| `--strategy.smooth_leader_to_follower_handover` | Whether to enable the smooth leader -> follower handover behavior.         |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Inference Backends
|
## Inference Backends
|
||||||
|
|||||||
@@ -214,6 +214,7 @@ groot = [
|
|||||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||||
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
||||||
topreward = ["lerobot[transformers-dep]"]
|
topreward = ["lerobot[transformers-dep]"]
|
||||||
|
recap = ["lerobot[transformers-dep]"]
|
||||||
xvla = ["lerobot[transformers-dep]"]
|
xvla = ["lerobot[transformers-dep]"]
|
||||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
@@ -296,6 +297,7 @@ all = [
|
|||||||
"lerobot[sarm]",
|
"lerobot[sarm]",
|
||||||
"lerobot[robometer]",
|
"lerobot[robometer]",
|
||||||
"lerobot[topreward]",
|
"lerobot[topreward]",
|
||||||
|
"lerobot[recap]",
|
||||||
"lerobot[peft]",
|
"lerobot[peft]",
|
||||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from __future__ import annotations
|
|||||||
# Utilities
|
# Utilities
|
||||||
########################################################################################
|
########################################################################################
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from copy import copy
|
from copy import copy
|
||||||
@@ -243,3 +244,72 @@ def sanity_check_dataset_robot_compatibility(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
########################################################################################
|
||||||
|
# Teleoperator smooth handover helpers
|
||||||
|
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
|
||||||
|
# being a root module.
|
||||||
|
########################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def teleop_supports_feedback(teleop) -> bool:
    """Return True when the teleop can receive position feedback (is actuated).

    A teleop is considered actuated when it reports a non-empty
    ``feedback_features`` and exposes both the ``enable_torque`` and
    ``disable_torque`` motor-control methods (e.g. SO-101, OpenArmMini).

    A teleop that does not define ``feedback_features`` at all is treated as
    non-actuated rather than raising ``AttributeError``.

    TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.

    Args:
        teleop: The teleoperator object to inspect (duck-typed).

    Returns:
        True if the teleop has feedback features and both torque methods, False otherwise.
    """
    # Missing attribute and empty container both mean "nothing to send feedback to".
    if not getattr(teleop, "feedback_features", None):
        return False

    return all(hasattr(teleop, method_name) for method_name in ("disable_torque", "enable_torque"))
|
||||||
|
|
||||||
|
|
||||||
|
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
    """Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.

    Requires the teleoperator to support feedback (i.e. have non-empty
    ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
    This function does not check that itself; it calls ``enable_torque``,
    ``get_action`` and ``send_feedback`` directly.

    ``target_pos`` is expected to be in the teleop's action/feedback key space.
    For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
    the robot action key space directly. Keys reported by ``get_action`` that are
    absent from ``target_pos`` are held at their current value.

    TODO(Maxime): This blocks for roughly ``duration_s`` seconds; during this time the
    follower robot does not receive new actions, which could be an issue on LeKiwi.

    Args:
        teleop: Actuated teleoperator to drive.
        target_pos: Target value per action key.
        duration_s: Nominal duration of the motion in seconds.
        fps: Number of feedback frames sent per second; must be positive.

    Raises:
        ValueError: If ``fps`` is not positive. Raised before torque is enabled so
            that an invalid call never touches the motors.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps!r}")

    teleop.enable_torque()
    start_pos = teleop.get_action()
    num_steps = max(int(duration_s * fps), 1)
    frame_period_s = 1 / fps

    # num_steps + 1 frames: the first one re-sends the start pose, the last one is the target.
    for step in range(num_steps + 1):
        blend = step / num_steps
        feedback = {
            key: start_value * (1 - blend) + target_pos[key] * blend if key in target_pos else start_value
            for key, start_value in start_pos.items()
        }
        teleop.send_feedback(feedback)
        time.sleep(frame_period_s)
|
||||||
|
|
||||||
|
|
||||||
|
def follower_smooth_move_to(
    robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
) -> None:
    """Smoothly move the follower robot from ``current`` to ``target`` action.

    Used when the teleop is non-actuated: instead of driving the leader arm to
    the follower, the follower is brought to the teleop's current pose so the
    robot meets the operator's hand rather than jumping to it on the first frame.

    Both ``current`` and ``target`` must be in the robot action key space
    (i.e. the output of ``robot_action_processor``). Keys of ``current`` that are
    absent from ``target`` are held at their current value.

    Args:
        robot: Robot exposing ``send_action``.
        current: Starting action, one value per key.
        target: Target action; may cover only a subset of the keys in ``current``.
        duration_s: Nominal duration of the motion in seconds.
        fps: Number of actions sent per second; must be positive.

    Raises:
        ValueError: If ``fps`` is not positive. Raised before any action is sent.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps!r}")

    num_steps = max(int(duration_s * fps), 1)
    frame_period_s = 1 / fps

    # num_steps + 1 frames: the first one equals ``current``, the last one reaches ``target``.
    for step in range(num_steps + 1):
        blend = step / num_steps
        action = {
            key: start_value * (1 - blend) + target[key] * blend if key in target else start_value
            for key, start_value in current.items()
        }
        robot.send_action(action)
        time.sleep(frame_period_s)
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable, Iterable, Sequence
|
from collections.abc import Callable, Iterable, Sequence
|
||||||
@@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
|
|
||||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||||
|
_serialized_state_filenames: tuple[str | None, ...] | None = field(
|
||||||
|
default=None,
|
||||||
|
init=False,
|
||||||
|
repr=False,
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, data: TInput) -> TOutput:
|
def __call__(self, data: TInput) -> TOutput:
|
||||||
"""Processes input data through the full pipeline.
|
"""Processes input data through the full pipeline.
|
||||||
@@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
transition = processor_step(transition)
|
transition = processor_step(transition)
|
||||||
yield transition
|
yield transition
|
||||||
|
|
||||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
def _get_sanitized_name(self) -> str:
|
||||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
"""Return a filename-safe version of the pipeline name.
|
||||||
|
|
||||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
Returns:
|
||||||
|
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
|
||||||
"""
|
"""
|
||||||
config_filename = kwargs.pop("config_filename", None)
|
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||||
|
|
||||||
# Sanitize the pipeline name to create a valid filename prefix.
|
@staticmethod
|
||||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
def _get_state_filename(
|
||||||
|
*,
|
||||||
|
step_index: int,
|
||||||
|
registry_name: str | None,
|
||||||
|
sanitized_name: str,
|
||||||
|
) -> str:
|
||||||
|
"""Return the safetensors filename for one stateful processor step.
|
||||||
|
|
||||||
if config_filename is None:
|
Args:
|
||||||
config_filename = f"{sanitized_name}.json"
|
step_index: The index of the processor step in this pipeline.
|
||||||
|
registry_name: The registered processor step name, if available.
|
||||||
|
sanitized_name: The filename-safe pipeline name.
|
||||||
|
|
||||||
config: dict[str, Any] = {
|
Returns:
|
||||||
|
The state filename used by the existing disk serialization format.
|
||||||
|
"""
|
||||||
|
if registry_name:
|
||||||
|
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||||
|
|
||||||
|
return f"{sanitized_name}_step_{step_index}.safetensors"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_state_key(state_filename: str) -> str:
|
||||||
|
"""Return the in-memory state key for a serialized state filename.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_filename: The `.safetensors` filename from the serialized config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The state key used by the in-memory pipeline state dictionary.
|
||||||
|
"""
|
||||||
|
return state_filename.removesuffix(".safetensors")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
|
||||||
|
"""Return serialized state filenames in step order.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loaded_config: A validated processor pipeline config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing each step's serialized state filename, or None for stateless steps.
|
||||||
|
"""
|
||||||
|
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
|
||||||
|
|
||||||
|
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
|
||||||
|
"""Return expected state filenames in step order for `load_state_dict()`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The preserved serialized state filenames when available, otherwise filenames derived from
|
||||||
|
current non-empty step state.
|
||||||
|
"""
|
||||||
|
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
|
||||||
|
self.steps
|
||||||
|
):
|
||||||
|
return self._serialized_state_filenames
|
||||||
|
|
||||||
|
sanitized_name = self._get_sanitized_name()
|
||||||
|
state_filenames: list[str | None] = []
|
||||||
|
|
||||||
|
for step_index, processor_step in enumerate(self.steps):
|
||||||
|
step_state_dict = processor_step.state_dict()
|
||||||
|
if not step_state_dict:
|
||||||
|
state_filenames.append(None)
|
||||||
|
continue
|
||||||
|
|
||||||
|
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||||
|
state_filenames.append(
|
||||||
|
self._get_state_filename(
|
||||||
|
step_index=step_index,
|
||||||
|
registry_name=registry_name,
|
||||||
|
sanitized_name=sanitized_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return tuple(state_filenames)
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
"""Return the JSON-serializable pipeline configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with the same content that `save_pretrained()` writes as JSON.
|
||||||
|
"""
|
||||||
|
sanitized_name = self._get_sanitized_name()
|
||||||
|
pipeline_config: dict[str, Any] = {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"steps": [],
|
"steps": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Iterate through each step to build its configuration entry.
|
|
||||||
for step_index, processor_step in enumerate(self.steps):
|
for step_index, processor_step in enumerate(self.steps):
|
||||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||||
|
|
||||||
step_entry: dict[str, Any] = {}
|
step_entry: dict[str, Any] = {}
|
||||||
# Prefer registry name for portability, otherwise fall back to full class path.
|
|
||||||
if registry_name:
|
if registry_name:
|
||||||
step_entry["registry_name"] = registry_name
|
step_entry["registry_name"] = registry_name
|
||||||
else:
|
else:
|
||||||
@@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save step configuration if `get_config` is implemented.
|
step_entry["config"] = processor_step.get_config()
|
||||||
if hasattr(processor_step, "get_config"):
|
|
||||||
step_entry["config"] = processor_step.get_config()
|
|
||||||
|
|
||||||
# Save step state if `state_dict` is implemented and returns a non-empty dict.
|
step_state_dict = processor_step.state_dict()
|
||||||
if hasattr(processor_step, "state_dict"):
|
if step_state_dict:
|
||||||
state = processor_step.state_dict()
|
step_entry["state_file"] = self._get_state_filename(
|
||||||
if state:
|
step_index=step_index,
|
||||||
# Clone tensors to avoid modifying the original state.
|
registry_name=registry_name,
|
||||||
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
|
sanitized_name=sanitized_name,
|
||||||
|
)
|
||||||
|
|
||||||
# Create a unique filename for the state file.
|
pipeline_config["steps"].append(step_entry)
|
||||||
if registry_name:
|
|
||||||
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
|
||||||
else:
|
|
||||||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
|
||||||
|
|
||||||
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
return pipeline_config
|
||||||
step_entry["state_file"] = state_filename
|
|
||||||
|
|
||||||
config["steps"].append(step_entry)
|
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
|
||||||
|
"""Return pipeline state tensors grouped by state key.
|
||||||
|
|
||||||
# Write the main configuration JSON file.
|
Returns:
|
||||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
A dictionary mapping suffixless state keys to cloned step state dictionaries.
|
||||||
json.dump(config, file_pointer, indent=2)
|
"""
|
||||||
|
sanitized_name = self._get_sanitized_name()
|
||||||
|
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||||
|
|
||||||
|
for step_index, processor_step in enumerate(self.steps):
|
||||||
|
step_state_dict = processor_step.state_dict()
|
||||||
|
if not step_state_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||||
|
state_filename = self._get_state_filename(
|
||||||
|
step_index=step_index,
|
||||||
|
registry_name=registry_name,
|
||||||
|
sanitized_name=sanitized_name,
|
||||||
|
)
|
||||||
|
state_key = self._get_state_key(state_filename)
|
||||||
|
pipeline_state_dict[state_key] = {
|
||||||
|
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return pipeline_state_dict
|
||||||
|
|
||||||
|
def load_state_dict(
|
||||||
|
self,
|
||||||
|
state_dict: dict[str, dict[str, torch.Tensor]],
|
||||||
|
) -> None:
|
||||||
|
"""Load pipeline state tensors into the existing steps.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If loading finds missing expected state or unexpected extra state.
|
||||||
|
"""
|
||||||
|
expected_state_filenames = self._get_state_filenames_for_loading()
|
||||||
|
used_state_keys: set[str] = set()
|
||||||
|
|
||||||
|
for step_index, (processor_step, state_filename) in enumerate(
|
||||||
|
zip(self.steps, expected_state_filenames, strict=True)
|
||||||
|
):
|
||||||
|
if state_filename is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
state_key = self._get_state_key(state_filename)
|
||||||
|
if state_key not in state_dict:
|
||||||
|
raise KeyError(
|
||||||
|
f"Missing state key '{state_key}' for processor step {step_index}. "
|
||||||
|
f"Available state keys: {sorted(state_dict.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
processor_step.load_state_dict(state_dict[state_key])
|
||||||
|
used_state_keys.add(state_key)
|
||||||
|
|
||||||
|
unexpected_state_keys = set(state_dict) - used_state_keys
|
||||||
|
if unexpected_state_keys:
|
||||||
|
expected_state_key_set = {
|
||||||
|
self._get_state_key(state_filename)
|
||||||
|
for state_filename in expected_state_filenames
|
||||||
|
if state_filename is not None
|
||||||
|
}
|
||||||
|
raise KeyError(
|
||||||
|
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
|
||||||
|
f"Expected state keys: {sorted(expected_state_key_set)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
|
||||||
|
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||||
|
|
||||||
|
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||||
|
"""
|
||||||
|
config_filename = kwargs.pop("config_filename", None)
|
||||||
|
sanitized_name = self._get_sanitized_name()
|
||||||
|
|
||||||
|
if config_filename is None:
|
||||||
|
config_filename = f"{sanitized_name}.json"
|
||||||
|
|
||||||
|
pipeline_config = self.get_config()
|
||||||
|
pipeline_state_dict = self.state_dict()
|
||||||
|
|
||||||
|
for state_key, step_state_dict in pipeline_state_dict.items():
|
||||||
|
state_filename = f"{state_key}.safetensors"
|
||||||
|
save_file(step_state_dict, save_directory / state_filename)
|
||||||
|
|
||||||
|
with open(save_directory / config_filename, "w") as file_pointer:
|
||||||
|
json.dump(pipeline_config, file_pointer, indent=2)
|
||||||
|
|
||||||
def save_pretrained(
|
def save_pretrained(
|
||||||
self,
|
self,
|
||||||
@@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
cls._validate_overrides_used(validated_overrides, loaded_config)
|
cls._validate_overrides_used(validated_overrides, loaded_config)
|
||||||
|
|
||||||
# 5. Construct and return the final pipeline instance
|
# 5. Construct and return the final pipeline instance
|
||||||
return cls(
|
pipeline = cls(
|
||||||
steps=steps,
|
steps=steps,
|
||||||
name=loaded_config.get("name", "DataProcessorPipeline"),
|
name=loaded_config.get("name", "DataProcessorPipeline"),
|
||||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||||
)
|
)
|
||||||
|
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
@classmethod
def from_config(
    cls,
    config: dict[str, Any],
    *,
    state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
    overrides: dict[str, Any] | None = None,
    to_transition: Callable[[TInput], EnvTransition] | None = None,
    to_output: Callable[[EnvTransition], TOutput] | None = None,
) -> DataProcessorPipeline[TInput, TOutput]:
    """Create a pipeline directly from an in-memory config, optionally restoring state.

    The config is validated, steps are instantiated without reading any tensor
    files, and the config's ``state_file`` entries are recorded on the pipeline so
    that `load_state_dict()` can match keys against them.

    Args:
        config: A config dictionary with the same structure as the saved processor JSON.
        state_dict: Optional in-memory pipeline state grouped by suffixless state key.
        overrides: Optional constructor overrides keyed by registry name or class name.
        to_transition: Optional converter from input data to `EnvTransition`.
        to_output: Optional converter from `EnvTransition` to output data.

    Returns:
        A processor pipeline built from the config and optional state.
    """
    # There is no model id or file here, so the same placeholder labels both in error messages.
    source_label = "<in-memory config>"
    cls._validate_loaded_config(source_label, config, source_label)

    steps, unused_override_keys = cls._build_steps_from_config(config, overrides or {})
    cls._validate_overrides_used(unused_override_keys, config)

    transition_converter = to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition)
    output_converter = to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch)

    pipeline = cls(
        steps=steps,
        name=config.get("name", "DataProcessorPipeline"),
        to_transition=transition_converter,
        to_output=output_converter,
    )
    pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)

    if state_dict is None:
        return pipeline

    pipeline.load_state_dict(state_dict)
    return pipeline
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _load_config(
|
def _load_config(
|
||||||
@@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_loaded_config(
|
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
|
||||||
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
|
|
||||||
) -> None:
|
|
||||||
"""Validate that a config was loaded and is a valid processor config.
|
"""Validate that a config was loaded and is a valid processor config.
|
||||||
|
|
||||||
This method validates processor config format with intelligent migration detection:
|
This method validates processor config format with intelligent migration detection:
|
||||||
@@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_id: The model identifier (used for migration detection)
|
model_id: The model identifier (used for migration detection)
|
||||||
loaded_config: The loaded config dictionary (guaranteed non-None)
|
loaded_config: The loaded config value to validate (may be non-dict)
|
||||||
config_filename: The config filename that was loaded (for error messages)
|
config_filename: The config filename that was loaded (for error messages)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
model_id,
|
model_id,
|
||||||
f"Config file '{config_filename}' is not a valid processor configuration",
|
f"Config file '{config_filename}' is not a valid processor configuration",
|
||||||
)
|
)
|
||||||
|
loaded_config_description = (
|
||||||
|
list(loaded_config.keys())
|
||||||
|
if isinstance(loaded_config, dict)
|
||||||
|
else type(loaded_config).__name__
|
||||||
|
)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Config file '{config_filename}' is not a valid processor configuration. "
|
f"Config file '{config_filename}' is not a valid processor configuration. "
|
||||||
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
|
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
ImportError: If a step class cannot be imported or found in registry
|
ImportError: If a step class cannot be imported or found in registry
|
||||||
ValueError: If a step cannot be instantiated with its configuration
|
ValueError: If a step cannot be instantiated with its configuration
|
||||||
"""
|
"""
|
||||||
steps: list[ProcessorStep] = []
|
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
|
||||||
override_keys = set(overrides.keys())
|
|
||||||
|
|
||||||
for step_entry in loaded_config["steps"]:
|
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
|
||||||
# 1. Get step class and key
|
|
||||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
|
||||||
|
|
||||||
# 2. Instantiate step with overrides
|
|
||||||
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
|
||||||
|
|
||||||
# 3. Load step state if available
|
|
||||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||||
|
|
||||||
# 4. Track used overrides
|
return steps, remaining_override_keys
|
||||||
if step_key in override_keys:
|
|
||||||
override_keys.discard(step_key)
|
|
||||||
|
|
||||||
steps.append(step_instance)
|
@classmethod
|
||||||
|
def _build_steps_from_config(
|
||||||
|
cls,
|
||||||
|
loaded_config: dict[str, Any],
|
||||||
|
overrides: dict[str, Any],
|
||||||
|
) -> tuple[list[ProcessorStep], set[str]]:
|
||||||
|
"""Build processor steps from config without loading tensor state.
|
||||||
|
|
||||||
return steps, override_keys
|
Args:
|
||||||
|
loaded_config: The loaded processor configuration.
|
||||||
|
overrides: User-provided constructor overrides keyed by step key.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple containing instantiated steps and override keys that did not match a step.
|
||||||
|
"""
|
||||||
|
processor_steps: list[ProcessorStep] = []
|
||||||
|
remaining_override_keys = set(overrides.keys())
|
||||||
|
|
||||||
|
for step_entry in loaded_config["steps"]:
|
||||||
|
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||||
|
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||||
|
|
||||||
|
if step_key in remaining_override_keys:
|
||||||
|
remaining_override_keys.discard(step_key)
|
||||||
|
|
||||||
|
processor_steps.append(processor_step)
|
||||||
|
|
||||||
|
return processor_steps, remaining_override_keys
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
|
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
|
||||||
@@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _is_processor_config(cls, config: dict) -> bool:
|
def _is_processor_config(cls, config: Any) -> bool:
|
||||||
"""Check if config follows DataProcessorPipeline format.
|
"""Check if config follows DataProcessorPipeline format.
|
||||||
|
|
||||||
This method validates the processor configuration structure:
|
This method validates the processor configuration structure:
|
||||||
@@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
|||||||
Returns:
|
Returns:
|
||||||
True if config follows valid DataProcessorPipeline format, False otherwise
|
True if config follows valid DataProcessorPipeline format, False otherwise
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(config, dict):
|
||||||
|
return False
|
||||||
|
|
||||||
# Must have a "steps" field with a list of step configurations
|
# Must have a "steps" field with a list of step configurations
|
||||||
if not isinstance(config.get("steps"), list):
|
if not isinstance(config.get("steps"), list):
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -13,6 +13,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||||
|
from .distributional_value_function.configuration_distributional_value_function import (
|
||||||
|
DistributionalVFConfig as DistributionalVFConfig,
|
||||||
|
)
|
||||||
from .factory import (
|
from .factory import (
|
||||||
get_reward_model_class as get_reward_model_class,
|
get_reward_model_class as get_reward_model_class,
|
||||||
make_reward_model as make_reward_model,
|
make_reward_model as make_reward_model,
|
||||||
@@ -26,6 +29,7 @@ from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfi
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Configuration classes
|
# Configuration classes
|
||||||
|
"DistributionalVFConfig",
|
||||||
"RewardClassifierConfig",
|
"RewardClassifierConfig",
|
||||||
"RobometerConfig",
|
"RobometerConfig",
|
||||||
"SARMConfig",
|
"SARMConfig",
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||||
|
from .modeling_distributional_value_function import DistributionalVFRewardModel
|
||||||
|
from .processor_distributional_value_function import make_distributional_vf_pre_post_processors
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DistributionalVFConfig",
|
||||||
|
"DistributionalVFRewardModel",
|
||||||
|
"make_distributional_vf_pre_post_processors",
|
||||||
|
]
|
||||||
+108
@@ -0,0 +1,108 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Configuration for RECAP's distributional value function.
|
||||||
|
|
||||||
|
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||||
|
https://pi.website/blog/pistar06
|
||||||
|
|
||||||
|
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
|
||||||
|
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
|
||||||
|
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
|
||||||
|
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
|
||||||
|
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
|
||||||
|
|
||||||
|
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
|
||||||
|
with optional one-hot targets for terminal states; MC returns normalized per task.
|
||||||
|
Weights initialized from a pre-trained PI05 actor checkpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from lerobot.configs import FeatureType, NormalizationMode
|
||||||
|
from lerobot.configs.rewards import RewardModelConfig
|
||||||
|
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||||
|
|
||||||
|
|
||||||
|
@RewardModelConfig.register_subclass("distributional_value_function")
@dataclass
class DistributionalVFConfig(RewardModelConfig):
    """Hyper-parameters for RECAP's distributional value function.

    The model represents V^{pi_ref}(o_t, l) as a categorical distribution over
    ``num_value_bins`` bins covering ``[value_support_min, value_support_max]``.
    Training minimizes cross-entropy against HL-Gauss soft targets or a Dirac
    delta projection built from Monte Carlo returns (Eq. 1 in the paper).

    Architecture: the paper's value function is a 670M Gemma 3 VLM (its actor is
    a 4B Gemma 3). Here a truncated PaliGemma (``num_hidden_layers=6``,
    ``num_vision_layers=13``) is used to land at roughly 670M parameters, and the
    weights are initialized from the PI05 actor checkpoint.
    """

    # --- Backbone -----------------------------------------------------------
    paligemma_variant: str = "gemma_2b"
    num_hidden_layers: int = 6
    num_vision_layers: int = 13

    # --- Distributional head ------------------------------------------------
    num_value_bins: int = 201
    value_support_min: float = -1.0
    value_support_max: float = 0.0
    hl_gauss_sigma_ratio: float = 5.0

    # How targets are built: "hl_gauss" (default, soft) or "dirac_delta" (C51, hard).
    target_method: str = "hl_gauss"

    # When True, terminal states receive a one-hot target (exact return, no smoothing).
    # When False, terminal states go through the same target method as every other state.
    use_one_hot_terminal: bool = True

    # --- Image --------------------------------------------------------------
    image_resolution: tuple[int, int] = (224, 224)

    # --- Tokenizer ----------------------------------------------------------
    tokenizer_max_length: int = 64

    # --- Actor initialization -----------------------------------------------
    # Needed for the first training run: supplies the SigLIP vision tower and Gemma embeddings.
    # Give a PI05 checkpoint path or a Hub repo_id.
    # Once trained, load the value function through RewardModel.from_pretrained() instead.
    init_from_actor_path: str = ""

    # --- Normalization ------------------------------------------------------
    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.IDENTITY,
        }
    )

    def get_optimizer_preset(self) -> AdamWConfig:
        """Return the default AdamW settings for training the value function."""
        preset = AdamWConfig(lr=3e-4, weight_decay=1e-4, grad_clip_norm=1.0)
        return preset

    def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
        """Return the default cosine-decay-with-warmup schedule settings."""
        preset = CosineDecayWithWarmupSchedulerConfig(num_warmup_steps=500, num_decay_steps=50000)
        return preset

    def validate_features(self) -> None:
        """Require at least one VISUAL input feature when input features are set.

        Raises:
            ValueError: if ``input_features`` is non-empty and has no VISUAL entry.
        """
        declared_features = self.input_features
        if not declared_features:
            # Nothing declared yet: there is nothing to validate.
            return
        for feature in declared_features.values():
            if feature.type == FeatureType.VISUAL:
                return
        raise ValueError("DistributionalVFConfig requires at least one VISUAL input feature.")
|
||||||
+567
@@ -0,0 +1,567 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Modeling for RECAP's distributional value function.
|
||||||
|
|
||||||
|
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||||
|
https://pi.website/blog/pistar06
|
||||||
|
|
||||||
|
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
|
||||||
|
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
|
||||||
|
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
|
||||||
|
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
|
||||||
|
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
|
||||||
|
|
||||||
|
Inputs: single image observation + task text prompt ("Task: {task}.")
|
||||||
|
Outputs: softmax distribution over value bins; expected value E[V] for inference.
|
||||||
|
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
|
||||||
|
with optional one-hot targets for terminal states; MC returns normalized per task.
|
||||||
|
|
||||||
|
Weight initialization: vision tower, multi-modal projector, token embeddings, and
|
||||||
|
the first N transformer layers are copied from a pre-trained PI05 actor checkpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||||
|
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||||
|
|
||||||
|
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers.models.auto import CONFIG_MAPPING
|
||||||
|
from transformers.models.gemma import modeling_gemma
|
||||||
|
|
||||||
|
from lerobot.policies.pi_gemma import (
|
||||||
|
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||||
|
PiGemmaRMSNorm,
|
||||||
|
_gated_residual,
|
||||||
|
_get_pi_gemma_decoder_layer_base,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
CONFIG_MAPPING = None
|
||||||
|
modeling_gemma = None
|
||||||
|
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||||
|
PiGemmaRMSNorm = None
|
||||||
|
_gated_residual = None
|
||||||
|
_get_pi_gemma_decoder_layer_base = None
|
||||||
|
|
||||||
|
PALIGEMMA_VOCAB_SIZE = 257152
|
||||||
|
|
||||||
|
|
||||||
|
class DistributionalVFRewardModel(PreTrainedRewardModel):
|
||||||
|
"""Distributional value function model for RECAP.
|
||||||
|
|
||||||
|
Predicts V^{pi_ref}(o_t, l) as a categorical distribution over B bins (default 201).
|
||||||
|
Trained with cross-entropy on HL-Gauss or Dirac delta targets centered on
|
||||||
|
per-task normalized Monte Carlo returns.
|
||||||
|
|
||||||
|
Architecture: truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``),
|
||||||
|
causal attention, [CLS] token, and Linear(D, num_bins) value head.
|
||||||
|
The expected value is E[V] = sum(softmax(logits) * bin_centers).
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "distributional_value_function"
|
||||||
|
config_class = DistributionalVFConfig
|
||||||
|
|
||||||
|
def __init__(self, config: DistributionalVFConfig, **kwargs) -> None:
    """Build the truncated PaliGemma backbone, the [CLS] token and the value head.

    Args:
        config: model configuration; its backbone, head, image and
            ``init_from_actor_path`` fields are read here.
        **kwargs: accepted but not used anywhere in this constructor.
    """
    # Fails early if the optional `transformers` dependency is missing.
    require_package("transformers", extra="recap")
    super().__init__(config)
    self.config = config

    # Imported lazily so the module can be imported without transformers installed.
    from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding

    from lerobot.policies.pi05.modeling_pi05 import get_gemma_config

    # Get base dimensions from the paligemma variant (OpenPI config format)
    base_config = get_gemma_config(config.paligemma_variant)
    hidden_dim = base_config.width
    mlp_dim = base_config.mlp_dim
    # Depth comes from this config, not from the variant: this is the truncation.
    num_layers = config.num_hidden_layers

    # HuggingFace GemmaConfig for transformer layers
    gemma_config = CONFIG_MAPPING["gemma"](
        head_dim=base_config.head_dim,
        hidden_size=hidden_dim,
        intermediate_size=mlp_dim,
        num_attention_heads=base_config.num_heads,
        num_hidden_layers=num_layers,
        num_key_value_heads=base_config.num_kv_heads,
        vocab_size=PALIGEMMA_VOCAB_SIZE,
        hidden_activation="gelu_pytorch_tanh",
    )
    self.gemma_config = gemma_config
    self.hidden_dim = hidden_dim
    self.num_value_bins = config.num_value_bins

    # Single learned [CLS] token for value prediction
    # (initialized as standard-normal noise scaled by 0.02)
    self.cls_embedding = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)

    # Value projection head: Linear(hidden_dim, num_bins)
    self.value_head = nn.Linear(in_features=hidden_dim, out_features=config.num_value_bins)

    # Transformer layers (overwritten by _initialize_from_actor on first run)
    self.rotary_emb = GemmaRotaryEmbedding(gemma_config)
    pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
    self.layers = nn.ModuleList(
        [pi_gemma_decoder_layer_base(gemma_config, layer_idx=i) for i in range(num_layers)]
    )
    self.norm = PiGemmaRMSNorm(hidden_dim, eps=gemma_config.rms_norm_eps)

    # Vision tower + projector + token embedding (overwritten by _initialize_from_actor on first run)
    # PaliGemmaConfig wraps both vision and text configs into a single model
    paligemma_config = CONFIG_MAPPING["paligemma"]()
    paligemma_config.text_config = gemma_config
    # Only the first entry of image_resolution is used as the image size.
    paligemma_config.vision_config.image_size = config.image_resolution[0]
    paligemma_config.vision_config.intermediate_size = 4304
    paligemma_config.vision_config.projection_dim = 2048
    paligemma_config.vision_config.projector_hidden_act = "gelu_fast"

    # The full model is built only to harvest three sub-modules; the wrapper is then dropped.
    paligemma_full = PaliGemmaForConditionalGenerationWithPiGemma(config=paligemma_config)
    self.vision_tower = paligemma_full.model.vision_tower
    self.multi_modal_projector = paligemma_full.model.multi_modal_projector
    self.token_embedding = paligemma_full.model.language_model.embed_tokens
    del paligemma_full

    # Truncate vision tower to num_vision_layers
    # (skipped when the tower does not expose `vision_model.encoder`)
    if hasattr(self.vision_tower, "vision_model") and hasattr(self.vision_tower.vision_model, "encoder"):
        vision_encoder = self.vision_tower.vision_model.encoder
        vision_encoder.layers = vision_encoder.layers[: config.num_vision_layers]

    # Bin support: evenly spaced centers from value_support_min to value_support_max
    bin_centers = torch.linspace(config.value_support_min, config.value_support_max, self.num_value_bins)
    # persistent=False: the buffer is rebuilt from the config and kept out of the state dict.
    self.register_buffer("bin_centers", bin_centers, persistent=False)
    # NOTE(review): divides by (num_value_bins - 1), so num_value_bins == 1 raises ZeroDivisionError.
    bin_width = (config.value_support_max - config.value_support_min) / (self.num_value_bins - 1)
    # HL-Gauss standard deviation, expressed as a multiple of the bin width.
    self.hl_gauss_sigma = float(config.hl_gauss_sigma_ratio * bin_width)

    # Overwrite with pre-trained PI05 actor weights (first training run only)
    # An empty init_from_actor_path (the config default) skips this step.
    if config.init_from_actor_path:
        self._initialize_from_actor()
|
||||||
|
|
||||||
|
def _initialize_from_actor(self) -> None:
    """Copy backbone weights from the PI05 actor at ``config.init_from_actor_path``.

    Loads the rotary embedding, the first ``gemma_config.num_hidden_layers``
    decoder layers, the final norm, the (truncated) vision tower, the
    multi-modal projector and the token embedding table from the actor's
    PaliGemma. ``__init__`` invokes this only when ``init_from_actor_path``
    is set.
    """
    from lerobot.policies.pi05.modeling_pi05 import PI05Policy

    actor_policy = PI05Policy.from_pretrained(self.config.init_from_actor_path)
    actor_paligemma = actor_policy.model.paligemma_with_expert.paligemma
    actor_lm = actor_paligemma.model.language_model

    # Language-model side: rotary embedding, leading decoder layers, final norm.
    self.rotary_emb.load_state_dict(actor_lm.rotary_emb.state_dict())
    layer_count = self.gemma_config.num_hidden_layers
    for layer_idx in range(layer_count):
        donor_layer = actor_lm.layers[layer_idx]
        self.layers[layer_idx].load_state_dict(donor_layer.state_dict())
    self.norm.load_state_dict(actor_lm.norm.state_dict())

    # Vision side: shorten the donor encoder first so both state dicts line up.
    actor_vision_tower = actor_paligemma.model.vision_tower
    if hasattr(actor_vision_tower, "vision_model") and hasattr(
        actor_vision_tower.vision_model, "encoder"
    ):
        donor_encoder = actor_vision_tower.vision_model.encoder
        donor_encoder.layers = donor_encoder.layers[: self.config.num_vision_layers]
    self.vision_tower.load_state_dict(actor_vision_tower.state_dict())

    # Projector from vision features into the language-model width.
    actor_projector = actor_paligemma.model.multi_modal_projector
    self.multi_modal_projector.load_state_dict(actor_projector.state_dict())

    # Token embedding table.
    actor_embeddings = actor_paligemma.model.language_model.embed_tokens
    self.token_embedding.load_state_dict(actor_embeddings.state_dict())

    # Drop the reference to the actor once everything has been copied.
    del actor_policy
|
||||||
|
|
||||||
|
def embed_image(self, image: Tensor) -> Tensor:
    """Run images through the vision tower and projector.

    The vision tower is always fed float32 pixels; the result is cast back
    to the dtype the caller passed in.

    Args:
        image: [batch_size, channels, height, width] preprocessed images in [-1, 1].

    Returns:
        [batch_size, num_patches, hidden_dim] projected image features.
    """
    caller_dtype = image.dtype
    pixels = image if caller_dtype == torch.float32 else image.to(torch.float32)

    tower_output = self.vision_tower(pixels, return_dict=True)
    projected = self.multi_modal_projector(tower_output.last_hidden_state)
    # Scale the projected features by 1 / sqrt(hidden_dim).
    scaled = projected / (self.hidden_dim**0.5)

    if scaled.dtype == caller_dtype:
        return scaled
    return scaled.to(caller_dtype)
|
||||||
|
|
||||||
|
def embed_text(self, token_ids: Tensor) -> Tensor:
    """Look up embeddings for text tokens in the model's own embedding table.

    Args:
        token_ids: [batch_size, seq_len] integer token IDs

    Returns:
        [batch_size, seq_len, hidden_dim] text embeddings
    """
    text_embeddings = self.token_embedding(token_ids)
    return text_embeddings
|
||||||
|
|
||||||
|
def _get_cls_embedding(self, batch_size: int) -> Tensor:
|
||||||
|
"""Get [CLS] token embedding expanded to batch size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: number of samples in the batch.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[batch_size, 1, hidden_dim] learned [CLS] embedding.
|
||||||
|
"""
|
||||||
|
return self.cls_embedding.expand(batch_size, -1, -1)
|
||||||
|
|
||||||
|
def forward_value(
    self, vision_features: Tensor, text_embeddings: Tensor, text_padding_mask: Tensor
) -> dict[str, Tensor]:
    """Core forward pass through the distributional value function.

    Concatenates vision, text and [CLS] embeddings, runs them through the
    decoder layers under a causal + padding attention mask, and maps the
    final [CLS] hidden state to a distribution over value bins.

    Args:
        vision_features: [batch_size, num_patches, hidden_dim]
        text_embeddings: [batch_size, seq_len, hidden_dim]
        text_padding_mask: [batch_size, seq_len] boolean mask for text tokens

    Returns:
        Dict with keys:
            logits: [batch_size, num_value_bins]
            probs: [batch_size, num_value_bins] softmax of the logits
            value: [batch_size, 1] expectation of the bin centers under ``probs``
    """
    from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE

    batch_size = text_embeddings.shape[0]
    device = text_embeddings.device

    # Build sequence: [vision, text, CLS]
    cls_embedding = self._get_cls_embedding(batch_size)
    hidden_states = torch.cat([vision_features, text_embeddings, cls_embedding], dim=1)

    # Build causal attention mask
    # Vision and [CLS] positions are always marked valid; only text can be padded.
    vision_len = vision_features.shape[1]
    vision_padding_mask = torch.ones(batch_size, vision_len, dtype=torch.bool, device=device)
    cls_padding_mask = torch.ones(batch_size, 1, dtype=torch.bool, device=device)
    full_padding_mask = torch.cat([vision_padding_mask, text_padding_mask, cls_padding_mask], dim=1)

    full_seq_len = full_padding_mask.shape[1]

    # Causal mask
    causal_mask = torch.tril(torch.ones(full_seq_len, full_seq_len, device=device, dtype=torch.bool))
    # Combine causal mask with padding mask
    # The padding mask is broadcast along the query axis, so it masks padded *keys*.
    padding_mask_4d = full_padding_mask[:, None, None, :].expand(
        batch_size, 1, full_seq_len, full_seq_len
    )
    attention_mask = causal_mask[None, None, :, :] & padding_mask_4d
    # Convert the boolean mask to an additive one: 0.0 where attention is allowed,
    # OPENPI_ATTENTION_MASK_VALUE where it is blocked.
    attention_mask = torch.where(attention_mask, 0.0, OPENPI_ATTENTION_MASK_VALUE)

    # Positions count valid tokens only: a padded slot reuses the previous valid index.
    position_ids = torch.cumsum(full_padding_mask.long(), dim=1) - 1
    cos, sin = self.rotary_emb(hidden_states, position_ids)

    # Each decoder layer is unrolled by hand: norm -> attention -> residual -> norm -> MLP -> residual.
    for layer in self.layers:
        # The norm may return either a tensor or a (tensor, gate) pair.
        norm_output = layer.input_layernorm(hidden_states, cond=None)
        if isinstance(norm_output, tuple):
            hidden_states_normed, gate = norm_output
        else:
            hidden_states_normed, gate = norm_output, None

        # Split the projection output into heads: [..., num_heads, head_dim].
        input_shape = hidden_states_normed.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)

        query_states = layer.self_attn.q_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
        key_states = layer.self_attn.k_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
        value_states = layer.self_attn.v_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)

        query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
            query_states, key_states, cos, sin, unsqueeze_dim=1
        )

        # The second return value is discarded.
        attention_output, _ = modeling_gemma.eager_attention_forward(
            layer.self_attn,
            query_states,
            key_states,
            value_states,
            attention_mask,
            layer.self_attn.scaling,
        )

        # Merge the heads back into a single hidden dimension.
        attention_output = attention_output.reshape(batch_size, -1, self.gemma_config.hidden_size)
        # Match the output projection's weight dtype before applying it.
        if attention_output.dtype != layer.self_attn.o_proj.weight.dtype:
            attention_output = attention_output.to(layer.self_attn.o_proj.weight.dtype)
        projected_attention = layer.self_attn.o_proj(attention_output)

        # First residual connection (gated when the norm supplied a gate).
        if gate is not None:
            projected_attention = _gated_residual(hidden_states, projected_attention, gate)
        else:
            projected_attention = hidden_states + projected_attention

        # Keep a copy of the post-attention stream for the second residual.
        after_attention_residual = projected_attention.clone()

        norm_output = layer.post_attention_layernorm(projected_attention, cond=None)
        if isinstance(norm_output, tuple):
            mlp_input, gate = norm_output
        else:
            mlp_input, gate = norm_output, None

        mlp_output = layer.mlp(mlp_input)

        # Second residual connection (gated when the norm supplied a gate).
        if gate is not None:
            hidden_states = _gated_residual(after_attention_residual, mlp_output, gate)
        else:
            hidden_states = after_attention_residual + mlp_output

    # Final norm; like the per-layer norms it may return a tuple, of which only the first item is used.
    hidden_states = self.norm(hidden_states)
    if isinstance(hidden_states, tuple):
        hidden_states = hidden_states[0]

    # Extract [CLS] token (last position in the sequence)
    cls_hidden_state = hidden_states[:, -1, :]  # [batch_size, hidden_dim]

    # Value head: Linear(hidden_dim, num_bins) -> logits
    value_logits = self.value_head(cls_hidden_state)  # [batch_size, num_value_bins]
    value_probs = F.softmax(value_logits, dim=-1)
    # Expected value: sum over bins of probability * bin center.
    predicted_value = (value_probs * self.bin_centers.to(dtype=value_probs.dtype)).sum(
        dim=-1, keepdim=True
    )

    return {"logits": value_logits, "probs": value_probs, "value": predicted_value}
|
||||||
|
|
||||||
|
def hl_gauss_target(self, target_value: Tensor) -> Tensor:
    """Build HL-Gauss soft targets over the value bins.

    A Gaussian N(target, sigma^2) is integrated over every bin (difference
    of its CDF at the two bin edges) and the result is renormalized by the
    mass that falls inside the support, so each row sums to 1.

    Reference: Farebrother et al. 2024, "Stop Regressing: Training Value
    Functions via Classification for Scalable Deep RL", Section 3.1.
    arXiv:2403.03950

    Args:
        target_value: [batch_size] or [batch_size, 1] target values.

    Returns:
        [batch_size, num_value_bins] target probability distribution.
    """
    centers_dtype = self.bin_centers.dtype
    flat_target = target_value.squeeze(-1) if target_value.ndim == 2 else target_value
    flat_target = flat_target.to(dtype=centers_dtype)

    support_min = self.config.value_support_min
    support_max = self.config.value_support_max
    bin_count = self.num_value_bins

    # Edges sit half a bin-width outside the first and last bin centers.
    bin_width = (support_max - support_min) / (bin_count - 1)
    edges = torch.linspace(
        support_min - bin_width / 2,
        support_max + bin_width / 2,
        bin_count + 1,
        device=flat_target.device,
        dtype=flat_target.dtype,
    )

    # Gaussian CDF at every edge for every sample: [batch_size, num_bins + 1].
    standardized = (edges[None, :] - flat_target[:, None]) / (self.hl_gauss_sigma * math.sqrt(2))
    edge_cdf = 0.5 * (1.0 + torch.erf(standardized))

    # Mass inside the support, floored so the division below stays finite.
    in_support_mass = (edge_cdf[:, -1] - edge_cdf[:, 0]).unsqueeze(-1).clamp(min=1e-10)

    # Per-bin mass is the difference of consecutive CDF values.
    per_bin_mass = edge_cdf[:, 1:] - edge_cdf[:, :-1]
    return per_bin_mass / in_support_mass
|
||||||
|
|
||||||
|
def dirac_delta_target(self, target_value: Tensor) -> Tensor:
    """Dirac delta (C51) projection: split probability between two nearest bins.

    Standard distributional RL projection from Bellemare et al. 2017.
    "A Distributional Perspective on Reinforcement Learning"
    arXiv:1707.06887

    Targets are clamped to the support, located on the bin grid, and their
    unit mass is split linearly between the bin below and the bin above.
    A target that lands exactly on a bin center puts all mass on that bin.

    All arithmetic runs in at least float32, whatever dtype the
    ``bin_centers`` buffer has been cast to. Fractional bin positions are
    not representable in bf16/fp16 for large bin counts, and the output
    buffer must share a dtype with the weights written into it.

    Args:
        target_value: [batch_size] or [batch_size, 1] target values.

    Returns:
        [batch_size, num_value_bins] target probability distribution, in
        float32 (or float64 when ``bin_centers`` is float64).
    """
    if target_value.ndim == 2:
        target_value = target_value.squeeze(-1)

    # float32 for half-precision buffers, unchanged for float32 / float64.
    compute_dtype = torch.promote_types(self.bin_centers.dtype, torch.float32)
    support_min = self.config.value_support_min
    support_max = self.config.value_support_max
    last_bin = self.num_value_bins - 1

    target_value = target_value.to(dtype=compute_dtype).clamp(support_min, support_max)

    # Derive the bin width from the config (as __init__ and hl_gauss_target do)
    # rather than from the buffer, which may hold quantized values.
    bin_width = (support_max - support_min) / last_bin
    # Clamp so rounding error cannot push the position off the grid.
    normalized_position = ((target_value - support_min) / bin_width).clamp(0, last_bin)

    lower_bin_idx = normalized_position.floor().long()
    upper_bin_idx = normalized_position.ceil().long()

    # Linear interpolation weights. On an exact bin center (lower == upper)
    # weight_upper is 0 and weight_lower is 1, so the mass is counted once.
    weight_upper = normalized_position - lower_bin_idx.to(compute_dtype)
    weight_lower = 1.0 - weight_upper

    batch_size = target_value.shape[0]
    target_distribution = torch.zeros(
        batch_size, self.num_value_bins, device=target_value.device, dtype=compute_dtype
    )
    batch_indices = torch.arange(batch_size, device=target_value.device)
    target_distribution[batch_indices, lower_bin_idx] += weight_lower
    target_distribution[batch_indices, upper_bin_idx] += weight_upper

    return target_distribution
|
||||||
|
|
||||||
|
def one_hot_target(self, target_value: Tensor) -> Tensor:
    """Build a one-hot target on the bin whose center is closest to the return.

    Intended for terminal states, where the exact return is used without smoothing.

    Args:
        target_value: [batch_size] or [batch_size, 1] target values.

    Returns:
        [batch_size, num_value_bins] one-hot distribution at the nearest bin.
    """
    centers = self.bin_centers
    flat_target = target_value.squeeze(-1) if target_value.ndim == 2 else target_value
    flat_target = flat_target.to(dtype=centers.dtype)

    # Distance from every target to every bin center: [batch_size, num_value_bins].
    distances = torch.abs(centers.unsqueeze(0) - flat_target.unsqueeze(-1))
    closest_bin = torch.argmin(distances, dim=-1)

    encoded = F.one_hot(closest_bin, num_classes=self.num_value_bins)
    return encoded.to(dtype=centers.dtype)
|
||||||
|
|
||||||
|
def compute_target_distribution(
    self,
    target_value: Tensor,
    is_terminal: Tensor,
    method: str = "hl_gauss",
    use_one_hot_terminal: bool = True,
) -> Tensor:
    """Compute target distribution using configured method.

    Args:
        target_value: [batch_size] or [batch_size, 1] scalar return targets
        is_terminal: [batch_size] or [batch_size, 1] terminal flags
            (boolean, or anything ``.bool()`` can convert)
        method: "hl_gauss" or "dirac_delta"
        use_one_hot_terminal: if True, terminal states get one-hot targets
            (exact return, no smoothing). If False, all states use the same method.

    Returns:
        [batch_size, num_value_bins] target probability distribution

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    if method == "hl_gauss":
        base_distribution = self.hl_gauss_target(target_value)
    elif method == "dirac_delta":
        base_distribution = self.dirac_delta_target(target_value)
    else:
        raise ValueError(f"Unknown target method: {method}. Use 'hl_gauss' or 'dirac_delta'.")

    if not use_one_hot_terminal:
        return base_distribution

    terminal_distribution = self.one_hot_target(target_value)

    # Force the flags to [batch_size, 1]. Indexing with [:, None] turns
    # [batch_size, 1] flags into [batch_size, 1, 1], which torch.where would
    # silently broadcast to a [batch_size, batch_size, num_value_bins] result.
    terminal_mask = is_terminal.reshape(-1, 1).bool()

    return torch.where(terminal_mask, terminal_distribution, base_distribution)
|
||||||
|
|
||||||
|
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
    """Training step: cross-entropy between predicted and target value distributions.

    The batch must already have gone through the processor pipeline.
    Keys read from the batch:
        - observation.images.*: [B, C, H, W] preprocessed images (only the first
          matching key, in batch iteration order, is used)
        - observation.language_tokens: [B, seq_len] tokenized task prompt
        - observation.language_attention_mask: [B, seq_len] padding mask
        - mc_return: [B] normalized Monte Carlo return targets in (-1, 0)
        - is_terminal: [B] boolean terminal flags

    Returns:
        (loss, output_dict) where loss is scalar cross-entropy and output_dict
        holds Python floats for logging.

    Raises:
        KeyError: if the batch contains no image key.
    """
    from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS

    # Collect every image entry, then keep the first one.
    image_prefix = f"{OBS_IMAGES}."
    candidate_keys = [key for key in batch if key.startswith(image_prefix) or key == OBS_IMAGES]
    if len(candidate_keys) == 0:
        raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
    images = batch[candidate_keys[0]]

    token_ids = batch[OBS_LANGUAGE_TOKENS]
    text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
    mc_return = batch["mc_return"]
    is_terminal = batch["is_terminal"]

    # Encode both modalities and run the value transformer.
    vf_output = self.forward_value(
        self.embed_image(images),
        self.embed_text(token_ids),
        text_padding_mask,
    )
    value_logits = vf_output["logits"]
    predicted_value = vf_output["value"]

    # Targets follow the method and terminal handling chosen in the config.
    target_distribution = self.compute_target_distribution(
        mc_return,
        is_terminal,
        method=self.config.target_method,
        use_one_hot_terminal=self.config.use_one_hot_terminal,
    )

    # Cross-entropy loss (Eq. 1 in pi*0.6 paper), averaged over the batch.
    log_probs = F.log_softmax(value_logits, dim=-1)
    summed_log_likelihood = (target_distribution * log_probs).sum(dim=-1)
    loss = -summed_log_likelihood.mean()

    output_dict = {
        "loss": loss.item(),
        "predicted_value_mean": predicted_value.mean().item(),
        "mc_return_mean": mc_return.mean().item(),
    }

    return loss, output_dict
|
||||||
|
|
||||||
|
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
    """Predict V(s) for every observation in ``batch`` (used for advantage scoring).

    Args:
        batch: preprocessed batch holding images and the tokenized task text.
            Only the first image key found (in batch iteration order) is used.

    Returns:
        [batch_size] tensor of predicted values V(s).

    Raises:
        KeyError: if the batch contains no image key, or a language key is missing.
    """
    from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS

    image_prefix = f"{OBS_IMAGES}."
    candidates = [key for key in batch if key.startswith(image_prefix) or key == OBS_IMAGES]
    if not candidates:
        raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")

    images = batch[candidates[0]]
    token_ids = batch[OBS_LANGUAGE_TOKENS]
    text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()

    # Same embedding + transformer path as training, but only the value is kept.
    vf_output = self.forward_value(
        self.embed_image(images),
        self.embed_text(token_ids),
        text_padding_mask,
    )
    values = vf_output["value"]
    return values.squeeze(-1)  # [batch_size]
|
||||||
+235
@@ -0,0 +1,235 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Processor for RECAP's distributional value function.
|
||||||
|
|
||||||
|
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||||
|
https://pi.website/blog/pistar06
|
||||||
|
|
||||||
|
Prepares inputs for V^{pi_ref}(o_t, l): single image observation and task text only.
|
||||||
|
1. Image preprocessing (resize-with-pad + normalize to [-1, 1]) for SigLIP
|
||||||
|
2. Task prompt formatting ("Task: {task}.") and tokenization via PaliGemma tokenizer
|
||||||
|
|
||||||
|
Training targets (mc_return, is_terminal) are NOT routed through the processor.
|
||||||
|
They are dataset columns read directly from the batch in the model's forward().
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||||
|
from lerobot.processor import (
|
||||||
|
AddBatchDimensionProcessorStep,
|
||||||
|
DeviceProcessorStep,
|
||||||
|
NormalizerProcessorStep,
|
||||||
|
PolicyAction,
|
||||||
|
PolicyProcessorPipeline,
|
||||||
|
ProcessorStep,
|
||||||
|
ProcessorStepRegistry,
|
||||||
|
RenameObservationsProcessorStep,
|
||||||
|
TokenizerProcessorStep,
|
||||||
|
batch_to_transition,
|
||||||
|
policy_action_to_transition,
|
||||||
|
transition_to_batch,
|
||||||
|
)
|
||||||
|
from lerobot.processor.converters import to_tensor
|
||||||
|
from lerobot.types import EnvTransition, TransitionKey
|
||||||
|
from lerobot.utils.constants import (
|
||||||
|
OBS_IMAGES,
|
||||||
|
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
|
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||||
|
|
||||||
|
PALIGEMMA_TOKENIZER_NAME = "google/paligemma-3b-pt-224"
|
||||||
|
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register(name="distributional_vf_prepare_task_prompt")
@dataclass
class DistributionalVFPrepareTaskPromptStep(ProcessorStep):
    """Format the task string for the distributional value function.

    The value function receives only visual observations and task text.
    Each task string is cleaned (surrounding whitespace stripped, underscores
    and newlines replaced by spaces) and wrapped as ``"Task: {task}."``.
    """

    # Key under which the task string(s) live in the transition's complementary data.
    task_key: str = "task"

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a shallow copy of ``transition`` whose task entry holds formatted prompts.

        A single task string is treated as a one-element batch; the stored
        result is always a list of prompts.

        Raises:
            ValueError: if the complementary data is absent/``None`` or has no
                entry for ``task_key``.
        """
        transition = transition.copy()

        # ``.get(key, {})`` only substitutes the default for a *missing* key; a key that is
        # present but set to None would otherwise fail with AttributeError on the next line
        # instead of the intended ValueError.
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
        tasks = complementary_data.get(self.task_key)
        if tasks is None:
            raise ValueError("No task found in complementary data")

        if isinstance(tasks, str):
            tasks = [tasks]

        cleaned_tasks = (task.strip().replace("_", " ").replace("\n", " ") for task in tasks)
        full_prompts = [f"Task: {cleaned_text}." for cleaned_text in cleaned_tasks]

        # Write into a fresh dict so the caller's complementary data is not mutated.
        new_complementary_data = dict(complementary_data)
        new_complementary_data[self.task_key] = full_prompts
        transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
        return transition

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return ``features`` unchanged; this step adds or removes no features."""
        return features

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments of this step as a plain dict."""
        return {"task_key": self.task_key}
|
||||||
|
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register(name="distributional_vf_image_preprocessor")
@dataclass
class DistributionalVFImagePreprocessorStep(ProcessorStep):
    """Resize and normalize images for the value function's SigLIP vision tower.

    Expects float images in [0, 1].
    - Resize-with-pad to ``image_resolution`` (preserves aspect ratio)
    - Scale to [-1, 1] for SigLIP
    """

    # Target (height, width); compared against dims 1-2 of the channels-last image.
    image_resolution: tuple[int, int] = (224, 224)
    # Observation keys to process; when None/empty, every key equal to OBS_IMAGES
    # or starting with "OBS_IMAGES." is processed.
    image_keys: tuple[str, ...] | None = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of ``transition`` with the selected images resized and rescaled.

        Raises:
            ValueError: if the transition's observation is not a dict.
            KeyError: if no image key is configured or found in the observation.
        """
        from lerobot.policies.pi05.modeling_pi05 import resize_with_pad_torch

        observation = transition.get(TransitionKey.OBSERVATION)
        if not isinstance(observation, dict):
            raise ValueError("DistributionalVFImagePreprocessorStep requires an observation dict")

        image_keys = self.image_keys or tuple(
            key for key in observation if key == OBS_IMAGES or key.startswith(f"{OBS_IMAGES}.")
        )
        if not image_keys:
            raise KeyError(
                f"Distributional value function expected image keys under {OBS_IMAGES!r} in observation"
            )

        # Normalize to a tuple once: ``image_resolution`` may arrive as a list (e.g. after the
        # config returned by ``get_config`` has been through JSON), and a tuple-like
        # ``torch.Size`` never compares equal to a list, which would force a resize every time.
        target_resolution = tuple(self.image_resolution)

        new_observation = dict(observation)
        for image_key in image_keys:
            image = new_observation[image_key]
            if not isinstance(image, Tensor):
                image = to_tensor(image)
            if image.dtype != torch.float32:
                image = image.to(torch.float32)

            # Heuristic layout detection: a 4-D tensor with 3 entries on dim 1 is treated as
            # [B, C, H, W] and temporarily moved to channels-last for the resize helper.
            is_channels_first = image.ndim == 4 and image.shape[1] == 3
            if is_channels_first:
                image = image.permute(0, 2, 3, 1)

            if tuple(image.shape[1:3]) != target_resolution:
                image = resize_with_pad_torch(image, *target_resolution)

            # [0, 1] -> [-1, 1]
            image = image * 2.0 - 1.0

            if is_channels_first:
                image = image.permute(0, 3, 1, 2)

            new_observation[image_key] = image

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return ``features`` unchanged; this step does not alter the declared features."""
        return features

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments of this step as a plain dict."""
        return {
            "image_resolution": self.image_resolution,
            "image_keys": list(self.image_keys) if self.image_keys is not None else None,
        }
|
||||||
|
|
||||||
|
|
||||||
|
def _visual_image_keys(config: DistributionalVFConfig) -> tuple[str, ...]:
    """Return the names of ``config.input_features`` entries typed as VISUAL.

    Names are returned in the iteration order of ``config.input_features``;
    the tuple is empty when the config declares no visual feature.
    """
    visual_names = [
        name for name, spec in config.input_features.items() if spec.type == FeatureType.VISUAL
    ]
    return tuple(visual_names)
|
||||||
|
|
||||||
|
|
||||||
|
def make_distributional_vf_pre_post_processors(
    config: DistributionalVFConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Build the pre- and post-processing pipelines for the distributional value function.

    The preprocessor applies, in order:
        1. Observation renaming with an empty rename map
        2. Batch-dimension insertion
        3. Feature normalization driven by ``config.normalization_mapping`` and
           ``dataset_stats`` (images are expected to use identity and stay in
           [0, 1] -- confirm against the config)
        4. Task prompt formatting: "Task: {task}."
        5. Tokenization with the PaliGemma tokenizer, right-padded to
           ``config.tokenizer_max_length``
        6. Image resize-with-pad and scaling to [-1, 1] for SigLIP
        7. Transfer to ``config.device`` (``"cpu"`` when unset)

    Training targets (mc_return, is_terminal) are not handled here; the model
    reads them straight from the batch in forward().

    The postprocessor is built without any step, since the value function has
    no action to postprocess.

    Args:
        config: value-function configuration supplying features, normalization
            mapping, tokenizer length, image resolution and device.
        dataset_stats: optional per-feature statistics handed to the normalizer.

    Returns:
        ``(preprocessor, postprocessor)``.
    """
    visual_keys = _visual_image_keys(config)

    # Steps are instantiated in pipeline order.
    preprocessing_steps = [
        RenameObservationsProcessorStep(rename_map={}),
        AddBatchDimensionProcessorStep(),
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        DistributionalVFPrepareTaskPromptStep(),
        TokenizerProcessorStep(
            tokenizer_name=PALIGEMMA_TOKENIZER_NAME,
            max_length=config.tokenizer_max_length,
            padding_side="right",
            padding="max_length",
        ),
        DistributionalVFImagePreprocessorStep(
            image_resolution=config.image_resolution,
            # An empty tuple means "auto-detect image keys at call time".
            image_keys=visual_keys or None,
        ),
        DeviceProcessorStep(device=config.device or "cpu"),
    ]

    preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
        steps=preprocessing_steps,
        name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        to_transition=batch_to_transition,
        to_output=transition_to_batch,
    )
    postprocessor = PolicyProcessorPipeline(
        name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
        to_transition=policy_action_to_transition,
    )
    return preprocessor, postprocessor
|
||||||
@@ -24,6 +24,7 @@ from lerobot.configs.rewards import RewardModelConfig
|
|||||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||||
|
|
||||||
from .classifier.configuration_classifier import RewardClassifierConfig
|
from .classifier.configuration_classifier import RewardClassifierConfig
|
||||||
|
from .distributional_value_function.configuration_distributional_value_function import DistributionalVFConfig
|
||||||
from .pretrained import PreTrainedRewardModel
|
from .pretrained import PreTrainedRewardModel
|
||||||
from .robometer.configuration_robometer import RobometerConfig
|
from .robometer.configuration_robometer import RobometerConfig
|
||||||
from .sarm.configuration_sarm import SARMConfig
|
from .sarm.configuration_sarm import SARMConfig
|
||||||
@@ -63,6 +64,12 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
|||||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||||
|
|
||||||
return TOPRewardModel
|
return TOPRewardModel
|
||||||
|
elif name == "distributional_value_function":
|
||||||
|
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||||
|
DistributionalVFRewardModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
return DistributionalVFRewardModel
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return _get_reward_model_cls_from_name(name=name)
|
return _get_reward_model_cls_from_name(name=name)
|
||||||
@@ -96,6 +103,8 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
|||||||
return RobometerConfig(**kwargs)
|
return RobometerConfig(**kwargs)
|
||||||
elif reward_type == "topreward":
|
elif reward_type == "topreward":
|
||||||
return TOPRewardConfig(**kwargs)
|
return TOPRewardConfig(**kwargs)
|
||||||
|
elif reward_type == "distributional_value_function":
|
||||||
|
return DistributionalVFConfig(**kwargs)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
config_cls = RewardModelConfig.get_choice_class(reward_type)
|
config_cls = RewardModelConfig.get_choice_class(reward_type)
|
||||||
@@ -191,6 +200,16 @@ def make_reward_pre_post_processors(
|
|||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif isinstance(reward_cfg, DistributionalVFConfig):
|
||||||
|
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||||
|
make_distributional_vf_pre_post_processors,
|
||||||
|
)
|
||||||
|
|
||||||
|
return make_distributional_vf_pre_post_processors(
|
||||||
|
config=reward_cfg,
|
||||||
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
processors = _make_processors_from_reward_model_config(
|
processors = _make_processors_from_reward_model_config(
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from .configs import (
|
|||||||
DAggerKeyboardConfig,
|
DAggerKeyboardConfig,
|
||||||
DAggerPedalConfig,
|
DAggerPedalConfig,
|
||||||
DAggerStrategyConfig,
|
DAggerStrategyConfig,
|
||||||
|
EpisodicStrategyConfig,
|
||||||
HighlightStrategyConfig,
|
HighlightStrategyConfig,
|
||||||
RolloutConfig,
|
RolloutConfig,
|
||||||
RolloutStrategyConfig,
|
RolloutStrategyConfig,
|
||||||
@@ -49,6 +50,7 @@ from .inference import (
|
|||||||
from .strategies import (
|
from .strategies import (
|
||||||
BaseStrategy,
|
BaseStrategy,
|
||||||
DAggerStrategy,
|
DAggerStrategy,
|
||||||
|
EpisodicStrategy,
|
||||||
HighlightStrategy,
|
HighlightStrategy,
|
||||||
RolloutStrategy,
|
RolloutStrategy,
|
||||||
SentryStrategy,
|
SentryStrategy,
|
||||||
@@ -66,6 +68,8 @@ __all__ = [
|
|||||||
"HardwareContext",
|
"HardwareContext",
|
||||||
"HighlightStrategy",
|
"HighlightStrategy",
|
||||||
"HighlightStrategyConfig",
|
"HighlightStrategyConfig",
|
||||||
|
"EpisodicStrategy",
|
||||||
|
"EpisodicStrategyConfig",
|
||||||
"InferenceEngine",
|
"InferenceEngine",
|
||||||
"InferenceEngineConfig",
|
"InferenceEngineConfig",
|
||||||
"PolicyContext",
|
"PolicyContext",
|
||||||
|
|||||||
@@ -121,6 +121,35 @@ class DAggerPedalConfig:
|
|||||||
upload: str = "KEY_C"
|
upload: str = "KEY_C"
|
||||||
|
|
||||||
|
|
||||||
|
@RolloutStrategyConfig.register_subclass("episodic")
@dataclass
class EpisodicStrategyConfig(RolloutStrategyConfig):
    """Episode-oriented recording that mirrors the behavior of ``lerobot-record``.

    Records ``dataset.num_episodes`` episodes of at most ``dataset.episode_time_s`` seconds each.
    After each episode, runs ``dataset.reset_time_s`` seconds of reset time.

    Keyboard controls:
        Right arrow — end current episode or reset phase early
        Left arrow  — discard current episode and re-record
        Escape      — stop recording session

    In between episodes:
    - if there is no teleop leader, the robot is held at its initial joint positions captured at startup
      (see ``reset_to_initial_position`` to leave it where it is instead).
    - else, the robot is moved smoothly to the position of the teleop leader
      (see ``smooth_leader_to_follower_handover`` for the handover direction).
    """

    # This only applies if there are no teleop leaders specified.
    # When True (default), moves the robot back to the joint positions captured at startup.
    # Otherwise, leave the robot in its current position.
    reset_to_initial_position: bool = True

    # Whether to turn on or off the leader -> follower smooth handover behavior.
    # When False, fallback to follower -> leader handover.
    # Note that leader -> follower handover is only supported when the leader has `send_feedback` capability.
    # NOTE(review): presumably "leader -> follower" means the leader arm is driven to the
    # follower's pose, and "follower -> leader" the reverse — confirm against the strategy code.
    smooth_leader_to_follower_handover: bool = True
|
||||||
|
|
||||||
|
|
||||||
@RolloutStrategyConfig.register_subclass("dagger")
|
@RolloutStrategyConfig.register_subclass("dagger")
|
||||||
@dataclass
|
@dataclass
|
||||||
class DAggerStrategyConfig(RolloutStrategyConfig):
|
class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||||
@@ -229,7 +258,13 @@ class RolloutConfig:
|
|||||||
|
|
||||||
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
|
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
|
||||||
needs_dataset = isinstance(
|
needs_dataset = isinstance(
|
||||||
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
|
self.strategy,
|
||||||
|
(
|
||||||
|
SentryStrategyConfig,
|
||||||
|
HighlightStrategyConfig,
|
||||||
|
DAggerStrategyConfig,
|
||||||
|
EpisodicStrategyConfig,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
||||||
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
from .base import BaseStrategy
|
from .base import BaseStrategy
|
||||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||||
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
|
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
|
||||||
|
from .episodic import EpisodicStrategy
|
||||||
from .factory import create_strategy
|
from .factory import create_strategy
|
||||||
from .highlight import HighlightStrategy
|
from .highlight import HighlightStrategy
|
||||||
from .sentry import SentryStrategy
|
from .sentry import SentryStrategy
|
||||||
@@ -27,6 +28,7 @@ __all__ = [
|
|||||||
"DAggerPhase",
|
"DAggerPhase",
|
||||||
"DAggerStrategy",
|
"DAggerStrategy",
|
||||||
"HighlightStrategy",
|
"HighlightStrategy",
|
||||||
|
"EpisodicStrategy",
|
||||||
"RolloutStrategy",
|
"RolloutStrategy",
|
||||||
"SentryStrategy",
|
"SentryStrategy",
|
||||||
"create_strategy",
|
"create_strategy",
|
||||||
|
|||||||
@@ -56,10 +56,14 @@ from typing import Any
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from lerobot.common.control_utils import is_headless
|
from lerobot.common.control_utils import (
|
||||||
|
follower_smooth_move_to,
|
||||||
|
is_headless,
|
||||||
|
teleop_smooth_move_to,
|
||||||
|
teleop_supports_feedback,
|
||||||
|
)
|
||||||
from lerobot.datasets import VideoEncodingManager
|
from lerobot.datasets import VideoEncodingManager
|
||||||
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||||
from lerobot.teleoperators import Teleoperator
|
|
||||||
from lerobot.utils.constants import ACTION, OBS_STR
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
from lerobot.utils.feature_utils import build_dataset_frame
|
from lerobot.utils.feature_utils import build_dataset_frame
|
||||||
from lerobot.utils.import_utils import _pynput_available
|
from lerobot.utils.import_utils import _pynput_available
|
||||||
@@ -69,7 +73,6 @@ from lerobot.utils.utils import log_say
|
|||||||
|
|
||||||
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
|
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
|
||||||
from ..context import RolloutContext
|
from ..context import RolloutContext
|
||||||
from ..robot_wrapper import ThreadSafeRobot
|
|
||||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||||
|
|
||||||
PYNPUT_AVAILABLE = _pynput_available
|
PYNPUT_AVAILABLE = _pynput_available
|
||||||
@@ -171,64 +174,6 @@ class DAggerEvents:
|
|||||||
self.upload_requested.clear()
|
self.upload_requested.clear()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Teleoperator helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _teleop_supports_feedback(teleop: Teleoperator) -> bool:
|
|
||||||
"""Return True when the teleop can receive position feedback (is actuated).
|
|
||||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
|
||||||
"""
|
|
||||||
return (
|
|
||||||
bool(teleop.feedback_features)
|
|
||||||
and hasattr(teleop, "disable_torque")
|
|
||||||
and hasattr(teleop, "enable_torque")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _teleop_smooth_move_to(
|
|
||||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 30
|
|
||||||
) -> None:
|
|
||||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
|
||||||
|
|
||||||
Requires the teleoperator to support feedback
|
|
||||||
(i.e. have non-empty ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
|
||||||
|
|
||||||
TODO(Maxime): This blocks up to ``duration_s`` seconds, during this time
|
|
||||||
the follower robot doesn't receive new actions, this could be an issue on LeKiwi.
|
|
||||||
"""
|
|
||||||
teleop.enable_torque()
|
|
||||||
current = teleop.get_action()
|
|
||||||
steps = max(int(duration_s * fps), 1)
|
|
||||||
|
|
||||||
for step in range(steps + 1):
|
|
||||||
t = step / steps
|
|
||||||
interp = {
|
|
||||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
|
||||||
}
|
|
||||||
teleop.send_feedback(interp)
|
|
||||||
time.sleep(1 / fps)
|
|
||||||
|
|
||||||
|
|
||||||
def _follower_smooth_move_to(
|
|
||||||
robot: ThreadSafeRobot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
|
||||||
) -> None:
|
|
||||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
|
||||||
|
|
||||||
Used when the teleop is non-actuated: instead of driving the leader arm
|
|
||||||
to the follower, we bring the follower to the teleop's current pose.
|
|
||||||
Both ``current`` and ``target`` must be in robot-action key space.
|
|
||||||
"""
|
|
||||||
steps = max(int(duration_s * fps), 1)
|
|
||||||
|
|
||||||
for step in range(steps + 1):
|
|
||||||
t = step / steps
|
|
||||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
|
||||||
robot.send_action(interp)
|
|
||||||
time.sleep(1 / fps)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Input device handlers
|
# Input device handlers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -756,31 +701,31 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
logger.info("Pausing engine - robot holds position")
|
logger.info("Pausing engine - robot holds position")
|
||||||
engine.pause()
|
engine.pause()
|
||||||
|
|
||||||
if _teleop_supports_feedback(teleop) and prev_action is not None:
|
if teleop_supports_feedback(teleop) and prev_action is not None:
|
||||||
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
|
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
|
||||||
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
|
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
|
||||||
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
|
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
|
||||||
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
|
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
|
||||||
# _teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
# teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||||
logger.info("Smooth handover: moving leader arm to follower position")
|
logger.info("Smooth handover: moving leader arm to follower position")
|
||||||
_teleop_smooth_move_to(teleop, prev_action)
|
teleop_smooth_move_to(teleop, prev_action)
|
||||||
|
|
||||||
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
|
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
|
||||||
logger.info("Entering correction mode - human teleop control")
|
logger.info("Entering correction mode - human teleop control")
|
||||||
if not _teleop_supports_feedback(teleop) and prev_action is not None:
|
if not teleop_supports_feedback(teleop) and prev_action is not None:
|
||||||
logger.info("Smooth handover: sliding follower to teleop position")
|
logger.info("Smooth handover: sliding follower to teleop position")
|
||||||
obs = robot.get_observation()
|
obs = robot.get_observation()
|
||||||
teleop_action = teleop.get_action()
|
teleop_action = teleop.get_action()
|
||||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||||
target = ctx.processors.robot_action_processor((processed, obs))
|
target = ctx.processors.robot_action_processor((processed, obs))
|
||||||
_follower_smooth_move_to(robot, prev_action, target)
|
follower_smooth_move_to(robot, prev_action, target)
|
||||||
|
|
||||||
# unlock the teleop for human control
|
# unlock the teleop for human control
|
||||||
if _teleop_supports_feedback(teleop):
|
if teleop_supports_feedback(teleop):
|
||||||
teleop.disable_torque()
|
teleop.disable_torque()
|
||||||
|
|
||||||
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||||
if _teleop_supports_feedback(teleop):
|
if teleop_supports_feedback(teleop):
|
||||||
teleop.enable_torque()
|
teleop.enable_torque()
|
||||||
|
|
||||||
elif new_phase == DAggerPhase.AUTONOMOUS:
|
elif new_phase == DAggerPhase.AUTONOMOUS:
|
||||||
@@ -790,7 +735,7 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
engine.resume()
|
engine.resume()
|
||||||
|
|
||||||
# release teleop before resuming the policy
|
# release teleop before resuming the policy
|
||||||
if _teleop_supports_feedback(teleop):
|
if teleop_supports_feedback(teleop):
|
||||||
teleop.disable_torque()
|
teleop.disable_torque()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
@@ -0,0 +1,335 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Episodic rollout strategy: mirrors the behavior of ``lerobot-record``.
|
||||||
|
|
||||||
|
- Policy drives the robot during each recording episode.
|
||||||
|
- An optional teleoperator can drive the robot during reset phases so the
|
||||||
|
operator can bring the environment back to its starting configuration.
|
||||||
|
If no teleop is connected the robot stays in its current position.
|
||||||
|
- Keyboard controls:
|
||||||
|
|
||||||
|
Right arrow — end the current episode or reset phase early
|
||||||
|
Left arrow — discard the current episode and re-record it
|
||||||
|
Escape — stop the recording session
|
||||||
|
|
||||||
|
Dataset naming follows the rollout convention: repo names must start with ``rollout_``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
from lerobot.common.control_utils import (
|
||||||
|
follower_smooth_move_to,
|
||||||
|
init_keyboard_listener,
|
||||||
|
is_headless,
|
||||||
|
teleop_smooth_move_to,
|
||||||
|
teleop_supports_feedback,
|
||||||
|
)
|
||||||
|
from lerobot.datasets import VideoEncodingManager
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
|
from lerobot.utils.feature_utils import build_dataset_frame
|
||||||
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
|
from lerobot.utils.utils import log_say
|
||||||
|
from lerobot.utils.visualization_utils import log_rerun_data
|
||||||
|
|
||||||
|
from ..configs import EpisodicStrategyConfig
|
||||||
|
from ..context import RolloutContext
|
||||||
|
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodicStrategy(RolloutStrategy):
    """Policy-driven multi-episode recording, mirrors the behavior of ``lerobot-record``.

    Each recording episode runs the policy for at most ``dataset.episode_time_s``
    seconds and records a frame whenever an action was sent. A reset phase of
    at most ``dataset.reset_time_s`` seconds follows every episode except the
    last one (it also runs when the episode is about to be re-recorded) so the
    operator can reset the environment. During the reset phase an optional
    teleoperator drives the robot; if none is present and
    ``config.reset_to_initial_position`` is set, the robot is returned to its
    initial position instead.

    The policy state (hidden state, RTC queue, interpolator) is reset at
    the start of each recording episode.

    Keyboard events:
        right arrow  → end current episode or reset phase early
        left arrow   → discard & re-record current episode
        ESC          → stop the session
    """

    config: EpisodicStrategyConfig

    def __init__(self, config: EpisodicStrategyConfig) -> None:
        """Store the strategy config; listener and event dict are created in ``setup``."""
        super().__init__(config)
        self._listener = None
        self._events: dict | None = None

    def setup(self, ctx: RolloutContext) -> None:
        """Start the inference engine and attach the keyboard listener."""
        self._init_engine(ctx)
        self._listener, self._events = init_keyboard_listener()
        logger.info("Episodic strategy ready")

    def run(self, ctx: RolloutContext) -> None:
        """Main multi-episode recording loop.

        Alternates a policy-driven recording episode with a reset phase until
        ``dataset.num_episodes`` episodes have been saved, the ``stop_recording``
        event is raised, or the runtime shutdown event is set. Whatever is left
        in the episode buffer when the loop exits is saved on a best-effort basis.
        """
        cfg = ctx.runtime.cfg
        dataset_cfg = cfg.dataset
        robot = ctx.hardware.robot_wrapper
        teleop = ctx.hardware.teleop
        dataset = ctx.data.dataset
        events = self._events
        features = ctx.data.dataset_features
        shutdown_event = ctx.runtime.shutdown_event

        fps = cfg.fps
        episode_time_s = dataset_cfg.episode_time_s
        reset_time_s = dataset_cfg.reset_time_s
        num_episodes = dataset_cfg.num_episodes
        single_task = dataset_cfg.single_task or cfg.task
        play_sounds = cfg.play_sounds

        # Compression is forced on when streaming to a remote viewer (ip + port given).
        display_compressed = (
            True
            if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
            else cfg.display_compressed_images
        )

        with VideoEncodingManager(dataset):
            try:
                recorded_episodes = 0
                while recorded_episodes < num_episodes and not events["stop_recording"]:
                    if shutdown_event.is_set():
                        break

                    # Reset policy state at episode start (discard leftover hidden state / queue)
                    self._engine.reset()
                    self._interpolator.reset()
                    self._engine.resume()

                    log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
                    self._policy_loop(
                        ctx=ctx,
                        robot=robot,
                        events=events,
                        features=features,
                        fps=fps,
                        control_time_s=episode_time_s,
                        dataset=dataset,
                        single_task=single_task,
                    )

                    # Reset phase, skip after the last episode (but run when re-recording).
                    # It is also skipped once shutdown was requested so that no handover
                    # or return-to-start motion is commanded while the session is ending.
                    if (
                        not events["stop_recording"]
                        and not shutdown_event.is_set()
                        and (recorded_episodes < num_episodes - 1 or events["rerecord_episode"])
                    ):
                        log_say("Reset the environment", play_sounds)

                        if teleop:
                            # Smooth handover so the transition to teleop control is jerk-free.
                            # For actuated teleops: drive the leader arm to the follower's current
                            # position so the operator takes over without fighting the arm.
                            # For non-actuated teleops: slide the follower to the teleop's current
                            # pose instead, since the leader cannot be driven.
                            obs = robot.get_observation()
                            current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
                            if (
                                teleop_supports_feedback(teleop)
                                and self.config.smooth_leader_to_follower_handover
                            ):
                                logger.info("Smooth handover: moving leader arm to follower position")
                                teleop_smooth_move_to(teleop, current_pos, duration_s=2)
                                teleop.disable_torque()
                            else:
                                logger.info("Smooth handover: sliding follower to teleop position")
                                teleop_action = teleop.get_action()
                                processed = ctx.processors.teleop_action_processor((teleop_action, obs))
                                target = ctx.processors.robot_action_processor((processed, obs))
                                follower_smooth_move_to(robot, current_pos, target, duration_s=1)

                        elif self.config.reset_to_initial_position:
                            # No teleop: return the robot to its startup position.
                            self._return_to_initial_position(hw=ctx.hardware, duration_s=1)

                        self._reset_loop(
                            ctx=ctx,
                            robot=robot,
                            teleop=teleop,
                            events=events,
                            fps=fps,
                            control_time_s=reset_time_s,
                            display_data=cfg.display_data,
                            display_compressed=display_compressed,
                        )

                    if events["rerecord_episode"]:
                        log_say("Re-record episode", play_sounds)
                        events["rerecord_episode"] = False
                        events["exit_early"] = False
                        dataset.clear_episode_buffer()

                        # returns to its initial joint positions captured at startup
                        if not teleop and self.config.reset_to_initial_position:
                            self._return_to_initial_position(hw=ctx.hardware, duration_s=1)

                        # Do not count the discarded episode.
                        continue

                    dataset.save_episode()
                    recorded_episodes += 1
            finally:
                # Save any frames buffered in the current episode so an unexpected
                # exception or KeyboardInterrupt does not silently drop recorded data.
                # suppress: save_episode raises if the buffer is empty (nothing to lose).
                logger.info("Episodic control loop ended — saving any in-progress episode")
                with contextlib.suppress(Exception):
                    dataset.save_episode()

    def _policy_loop(
        self,
        ctx: RolloutContext,
        robot,
        events: dict,
        features: dict,
        fps: float,
        control_time_s: float,
        dataset,
        single_task: str,
    ) -> None:
        """Policy-driven recording loop for a single episode.

        Runs until ``control_time_s`` seconds have elapsed, the ``exit_early``
        event is raised (it is cleared here), or the shutdown event is set.
        A dataset frame is added only on iterations where an action was sent.
        """
        interpolator = self._interpolator
        control_interval = interpolator.get_control_interval(fps)

        timestamp = 0.0
        start_t = time.perf_counter()

        while timestamp < control_time_s:
            loop_start = time.perf_counter()

            if events["exit_early"]:
                events["exit_early"] = False
                break

            if ctx.runtime.shutdown_event.is_set():
                break

            obs = robot.get_observation()
            obs_processed = self._process_observation_and_notify(ctx.processors, obs)

            if self._handle_warmup(ctx.runtime.cfg.use_torch_compile, loop_start, control_interval):
                continue

            action_dict = send_next_action(obs_processed, obs, ctx, interpolator)

            if action_dict is not None:
                obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
                action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
                dataset.add_frame({**obs_frame, **action_frame, "task": single_task})
                self._log_telemetry(obs_processed, action_dict, ctx.runtime)

            dt = time.perf_counter() - loop_start
            sleep_t = control_interval - dt
            if sleep_t < 0:
                logger.warning(
                    f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({fps} Hz). "
                    "Dataset frames might be dropped and robot control might be unstable. "
                    "Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long "
                    "3) CPU starvation"
                )
            precise_sleep(max(sleep_t, 0.0))
            timestamp = time.perf_counter() - start_t

    def _reset_loop(
        self,
        ctx: RolloutContext,
        robot,
        teleop,
        events: dict,
        fps: float,
        control_time_s: float,
        display_data: bool,
        display_compressed: bool,
    ) -> None:
        """Reset-phase loop: teleop drives the robot if available, no recording.

        Runs until ``control_time_s`` seconds have elapsed, the ``exit_early``
        event is raised (it is cleared here), or the shutdown event is set.
        """
        processors = ctx.processors
        control_interval = 1.0 / fps

        timestamp = 0.0
        start_t = time.perf_counter()

        while timestamp < control_time_s:
            loop_start = time.perf_counter()

            if events["exit_early"]:
                events["exit_early"] = False
                break

            if ctx.runtime.shutdown_event.is_set():
                break

            obs = robot.get_observation()

            # Stays None when no teleop is connected, so the visualisation call
            # below never reads an unbound name.
            act_teleop = None
            if teleop is not None:
                act = teleop.get_action()
                act_teleop = processors.teleop_action_processor((act, obs))
                robot_action = processors.robot_action_processor((act_teleop, obs))
                robot.send_action(robot_action)

            if display_data:
                obs_processed = processors.robot_observation_processor(obs)
                log_rerun_data(
                    observation=obs_processed,
                    action=act_teleop,
                    compress_images=display_compressed,
                )

            dt = time.perf_counter() - loop_start
            sleep_t = control_interval - dt
            precise_sleep(max(sleep_t, 0.0))
            timestamp = time.perf_counter() - start_t

    def teardown(self, ctx: RolloutContext) -> None:
        """Finalise dataset, stop listener, push to hub, and disconnect hardware."""
        cfg = ctx.runtime.cfg
        play_sounds = cfg.play_sounds

        log_say("Stop recording", play_sounds, blocking=True)

        if not is_headless() and self._listener is not None:
            self._listener.stop()

        if ctx.data.dataset is not None:
            logger.info("Finalizing dataset...")
            ctx.data.dataset.finalize()

        # The upload is attempted only when requested; its truthy result gates the announcement.
        if (
            cfg.dataset is not None
            and cfg.dataset.push_to_hub
            and ctx.data.dataset is not None
            and safe_push_to_hub(
                ctx.data.dataset,
                tags=cfg.dataset.tags,
                private=cfg.dataset.private,
            )
        ):
            logger.info("Dataset uploaded to hub")
            log_say("Dataset uploaded to hub", play_sounds)

        self._teardown_hardware(
            ctx.hardware,
            return_to_initial_position=cfg.return_to_initial_position,
        )
        log_say("Exiting", play_sounds)
        logger.info("Episodic strategy teardown complete")
|
||||||
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
|
|||||||
from .base import BaseStrategy
|
from .base import BaseStrategy
|
||||||
from .core import RolloutStrategy
|
from .core import RolloutStrategy
|
||||||
from .dagger import DAggerStrategy
|
from .dagger import DAggerStrategy
|
||||||
|
from .episodic import EpisodicStrategy
|
||||||
from .highlight import HighlightStrategy
|
from .highlight import HighlightStrategy
|
||||||
from .sentry import SentryStrategy
|
from .sentry import SentryStrategy
|
||||||
|
|
||||||
@@ -42,4 +43,8 @@ def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
|||||||
return HighlightStrategy(config)
|
return HighlightStrategy(config)
|
||||||
if config.type == "dagger":
|
if config.type == "dagger":
|
||||||
return DAggerStrategy(config)
|
return DAggerStrategy(config)
|
||||||
raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
|
if config.type == "episodic":
|
||||||
|
return EpisodicStrategy(config)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger, episodic"
|
||||||
|
)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ Strategies
|
|||||||
--strategy.type=sentry Continuous recording with auto-upload
|
--strategy.type=sentry Continuous recording with auto-upload
|
||||||
--strategy.type=highlight Ring buffer + keystroke save
|
--strategy.type=highlight Ring buffer + keystroke save
|
||||||
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
|
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
|
||||||
|
--strategy.type=episodic Episode-oriented recording with reset phases
|
||||||
|
|
||||||
Inference backends
|
Inference backends
|
||||||
------------------
|
------------------
|
||||||
@@ -111,6 +112,18 @@ Usage examples
|
|||||||
--display_data=true \\
|
--display_data=true \\
|
||||||
--use_torch_compile=true
|
--use_torch_compile=true
|
||||||
|
|
||||||
|
# Episodic mode — episode-oriented recording with reset phases
|
||||||
|
lerobot-rollout \\
|
||||||
|
--strategy.type=episodic \\
|
||||||
|
--policy.path=user/my_policy \\
|
||||||
|
--robot.type=so100_follower \\
|
||||||
|
--robot.port=/dev/ttyACM0 \\
|
||||||
|
--teleop.type=so100_leader \\
|
||||||
|
--teleop.port=/dev/ttyACM1 \\
|
||||||
|
--dataset.repo_id=user/rollout_episodic_data \\
|
||||||
|
--dataset.num_episodes=20 \\
|
||||||
|
--dataset.single_task="Grab the cube"
|
||||||
|
|
||||||
# Resume a previous sentry recording session
|
# Resume a previous sentry recording session
|
||||||
lerobot-rollout \\
|
lerobot-rollout \\
|
||||||
--strategy.type=sentry \\
|
--strategy.type=sentry \\
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from typing import Any
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||||
|
|
||||||
@@ -174,6 +175,53 @@ class MockStepWithTensorState(ProcessorStep):
|
|||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
class MockLazyTensorStateStep(ProcessorStep):
    """Test double whose tensor state is kept out of its constructor config.

    ``get_config`` only reports ``name`` and ``scale``; the tensor is exposed
    through ``state_dict`` and only once it has been set or loaded.
    """

    def __init__(
        self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None
    ):
        self.name = name
        self.scale = scale
        # No tensor is materialised unless an initial value is supplied.
        self.tensor_state: torch.Tensor | None = (
            None if initial_value is None else torch.tensor([initial_value], dtype=torch.float32)
        )

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Pass the transition through untouched."""
        return transition

    def get_config(self) -> dict[str, Any]:
        """Report the constructor arguments only; tensor state is deliberately left out."""
        return dict(name=self.name, scale=self.scale)

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Expose the tensor state, or an empty mapping while it is still unset."""
        tensor = self.tensor_state
        return {} if tensor is None else {"tensor_state": tensor}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Adopt a private copy of the ``tensor_state`` entry."""
        incoming = state["tensor_state"]
        self.tensor_state = incoming.clone()

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Hand the feature description back unmodified."""
        return features
|
||||||
|
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register("registered_lazy_tensor_state_step")
class RegisteredLazyTensorStateStep(MockLazyTensorStateStep):
    """Registered lazy tensor state step for registry-based serialization tests.

    Adds no behavior of its own: it is the parent mock registered under the
    name ``registered_lazy_tensor_state_step``.
    """
|
||||||
|
|
||||||
|
|
||||||
def test_empty_pipeline():
|
def test_empty_pipeline():
|
||||||
"""Test pipeline with no steps."""
|
"""Test pipeline with no steps."""
|
||||||
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
|
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
|
||||||
@@ -620,6 +668,178 @@ def test_mixed_json_and_tensor_state():
|
|||||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_config_matches_saved_json():
    """Check that ``get_config`` is repeatable and equal to the JSON that ``save_pretrained`` writes."""
    steps = [
        MockStep(name="stateless"),
        MockLazyTensorStateStep(name="stateful", initial_value=4.0),
    ]
    pipeline = DataProcessorPipeline(steps, name="Memory Pipeline")

    expected_config = pipeline.get_config()
    # A second call must reproduce the first snapshot.
    assert pipeline.get_config() == expected_config

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)
        config_on_disk = json.loads((Path(tmp_dir) / "memory_pipeline.json").read_text())

        assert expected_config == config_on_disk

    step_entries = expected_config["steps"]
    # Only the step that owns tensor state advertises a state file.
    assert "state_file" not in step_entries[0]
    assert step_entries[1]["state_file"] == "memory_pipeline_step_1.safetensors"
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_dict_matches_saved_safetensors():
    """Check that the in-memory state equals the safetensors file and is detached from the live step."""
    step = MockLazyTensorStateStep(initial_value=7.0)
    pipeline = DataProcessorPipeline([step], name="Stateful Pipeline")

    snapshot = pipeline.state_dict()
    step_key = "stateful_pipeline_step_0"
    safetensors_name = "stateful_pipeline_step_0.safetensors"

    assert set(snapshot) == {step_key}
    assert set(snapshot[step_key]) == {"tensor_state"}

    # Mutating the snapshot in place must not leak into the step itself.
    snapshot[step_key]["tensor_state"].add_(1)
    assert step.tensor_state is not None
    assert torch.equal(step.tensor_state, torch.tensor([7.0]))

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)
        on_disk = load_file(Path(tmp_dir) / safetensors_name)

        torch.testing.assert_close(on_disk["tensor_state"], torch.tensor([7.0]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_pretrained_still_writes_expected_serialization_files():
    """Check that ``save_pretrained`` still produces the established config and state file names."""
    pipeline = DataProcessorPipeline(
        [MockLazyTensorStateStep(initial_value=3.0)],
        name="Policy Preprocessor",
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        # Existence has to be checked while the temporary directory is still alive.
        out_dir = Path(tmp_dir)
        config_file = out_dir / "policy_preprocessor.json"
        state_file = out_dir / "policy_preprocessor_step_0.safetensors"
        assert config_file.exists()
        assert state_file.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_config_round_trips_stateful_pipeline():
    """Check that ``from_config`` rebuilds a stateful pipeline from its in-memory config and state."""
    source = DataProcessorPipeline(
        [MockLazyTensorStateStep(name="roundtrip", initial_value=11.0)],
        name="Roundtrip Pipeline",
    )
    config = source.get_config()
    state = source.state_dict()

    rebuilt = DataProcessorPipeline.from_config(config, state_dict=state)
    rebuilt_step = rebuilt.steps[0]

    assert len(rebuilt) == 1
    assert isinstance(rebuilt_step, MockLazyTensorStateStep)
    torch.testing.assert_close(rebuilt_step.tensor_state, torch.tensor([11.0]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_config_round_trips_registered_stateful_pipeline():
    """Check that ``from_config`` resolves a registry-named step and restores its tensor state."""
    source = DataProcessorPipeline(
        [RegisteredLazyTensorStateStep(name="registered", initial_value=29.0)],
        name="Registry Pipeline",
    )
    config = source.get_config()
    state = source.state_dict()
    expected_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step"
    expected_file = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors"

    step_entry = config["steps"][0]
    assert step_entry["registry_name"] == "registered_lazy_tensor_state_step"
    assert step_entry["state_file"] == expected_file
    assert set(state) == {expected_key}

    rebuilt = DataProcessorPipeline.from_config(config, state_dict=state)
    rebuilt_step = rebuilt.steps[0]

    assert isinstance(rebuilt_step, RegisteredLazyTensorStateStep)
    assert rebuilt_step.tensor_state is not None
    torch.testing.assert_close(rebuilt_step.tensor_state, torch.tensor([29.0]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_config_preserves_state_metadata_for_empty_initial_state():
    """Check that a pipeline rebuilt without state can still load that state afterwards."""
    source = DataProcessorPipeline(
        [MockLazyTensorStateStep(name="lazy", initial_value=13.0)],
        name="Lazy Pipeline",
    )
    config = source.get_config()
    state = source.state_dict()

    # Rebuild from config alone: the step starts without any tensor.
    rebuilt = DataProcessorPipeline.from_config(config)
    rebuilt_step = rebuilt.steps[0]

    assert isinstance(rebuilt_step, MockLazyTensorStateStep)
    assert rebuilt_step.state_dict() == {}
    assert "state_file" not in rebuilt.get_config()["steps"][0]

    rebuilt.load_state_dict(state)

    torch.testing.assert_close(rebuilt_step.tensor_state, torch.tensor([13.0]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_config_applies_overrides_before_state_loading():
    """Check that constructor overrides and tensor-state loading do not interfere with each other."""
    source = DataProcessorPipeline(
        [MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0)],
        name="Override Pipeline",
    )
    config = source.get_config()
    state = source.state_dict()
    scale_override = {"MockLazyTensorStateStep": {"scale": 5.0}}

    rebuilt = DataProcessorPipeline.from_config(config, state_dict=state, overrides=scale_override)
    rebuilt_step = rebuilt.steps[0]

    assert isinstance(rebuilt_step, MockLazyTensorStateStep)
    # The override wins for the constructor argument, the tensor comes from the state dict.
    assert rebuilt_step.scale == 5.0
    torch.testing.assert_close(rebuilt_step.tensor_state, torch.tensor([17.0]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_state_dict_raises_on_missing_expected_state():
    """Check that loading an empty state fails when the serialized config expects a state entry."""
    source = DataProcessorPipeline(
        [MockLazyTensorStateStep(initial_value=19.0)],
        name="Missing Pipeline",
    )
    rebuilt = DataProcessorPipeline.from_config(source.get_config())

    empty_state = {}
    with pytest.raises(KeyError, match="missing_pipeline_step_0"):
        rebuilt.load_state_dict(empty_state)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_state_dict_raises_on_unexpected_extra_state():
    """Check that an unknown top-level state key is rejected."""
    stateless = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline")
    stray_state = {"extra": {"tensor_state": torch.tensor([1.0])}}

    with pytest.raises(KeyError, match="extra"):
        stateless.load_state_dict(stray_state)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stateless_pipeline_in_memory_serialization_returns_empty_state():
    """Check in-memory save/load of a stateless pipeline, including the default name fallback."""
    source = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline")
    step_entries = source.get_config()["steps"]
    # Dropping "name" lets the rebuilt pipeline fall back to its default name.
    nameless_config = {"steps": step_entries}

    assert source.state_dict() == {}
    assert all("state_file" not in entry for entry in step_entries)

    rebuilt = DataProcessorPipeline.from_config(nameless_config, state_dict={})

    assert rebuilt.name == "DataProcessorPipeline"
    assert rebuilt.state_dict() == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("invalid_config", [None, [], "not config"])
def test_from_config_rejects_non_dict_config(invalid_config):
    """Check that a non-dict top-level config is rejected with a ``ValueError``."""
    expected_message = "not a valid processor configuration"
    with pytest.raises(ValueError, match=expected_message):
        DataProcessorPipeline.from_config(invalid_config)  # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
class MockModuleStep(ProcessorStep, nn.Module):
|
class MockModuleStep(ProcessorStep, nn.Module):
|
||||||
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,518 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Tests for RECAP's distributional value function."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.rewards import RewardModelConfig
|
||||||
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
from lerobot.rewards.distributional_value_function.configuration_distributional_value_function import (
|
||||||
|
DistributionalVFConfig,
|
||||||
|
)
|
||||||
|
from lerobot.types import TransitionKey
|
||||||
|
from lerobot.utils.constants import OBS_IMAGES
|
||||||
|
from tests.utils import skip_if_package_missing
|
||||||
|
|
||||||
|
# Default number of samples produced by ``_make_batch``.
BATCH_SIZE = 4
# Width of the value distribution the HL-Gauss tests expect per sample.
NUM_BINS = 201
# Observation key of the single camera input declared by ``_make_config``.
IMAGE_KEY = f"{OBS_IMAGES}.top"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_config(**overrides) -> DistributionalVFConfig:
    """Build a CPU test config with one 224x224 camera input; ``overrides`` replace constructor defaults."""
    kwargs = {
        "init_from_actor_path": "",
        "device": "cpu",
        "image_resolution": (224, 224),
        **overrides,
    }
    config = DistributionalVFConfig(**kwargs)

    # Features and normalization are assigned after construction, as attributes.
    camera_feature = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
    config.input_features = {IMAGE_KEY: camera_feature}
    config.output_features = {}
    config.normalization_mapping = {"VISUAL": NormalizationMode.IDENTITY}
    return config
|
||||||
|
|
||||||
|
|
||||||
|
def _make_model():
    """Instantiate the reward model from the default test config (modeling module imported lazily)."""
    from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
        DistributionalVFRewardModel,
    )

    config = _make_config()
    return DistributionalVFRewardModel(config)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_batch(batch_size: int = BATCH_SIZE, device: str = "cpu") -> dict[str, torch.Tensor]:
    """Create a random batch with one image, language tokens and mask, returns and terminal flags."""
    from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS

    # Tensors are created in the same order as the dict entries so RNG consumption is unchanged.
    images = torch.rand(batch_size, 3, 224, 224, device=device)
    tokens = torch.randint(0, 1000, (batch_size, 16), device=device)
    attention_mask = torch.ones(batch_size, 16, dtype=torch.bool, device=device)
    mc_return = torch.rand(batch_size, device=device) * -1.0
    is_terminal = torch.zeros(batch_size, dtype=torch.bool, device=device)

    return {
        IMAGE_KEY: images,
        OBS_LANGUAGE_TOKENS: tokens,
        OBS_LANGUAGE_ATTENTION_MASK: attention_mask,
        "mc_return": mc_return,
        "is_terminal": is_terminal,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_registered_in_reward_model_registry():
    """The ``distributional_value_function`` choice is listed by the RewardModelConfig registry."""
    registered_choices = RewardModelConfig.get_known_choices()
    assert "distributional_value_function" in registered_choices
|
||||||
|
|
||||||
|
|
||||||
|
def test_factory_returns_correct_class():
    """The reward-model factory maps ``distributional_value_function`` to DistributionalVFRewardModel."""
    from lerobot.rewards.factory import get_reward_model_class

    # Resolve through the factory first; the modeling module is imported only afterwards.
    resolved = get_reward_model_class("distributional_value_function")

    from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
        DistributionalVFRewardModel,
    )

    assert resolved is DistributionalVFRewardModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_reward_model_config_factory():
    """The config factory returns a DistributionalVFConfig and forwards keyword overrides."""
    from lerobot.rewards.factory import make_reward_model_config

    built = make_reward_model_config("distributional_value_function", num_value_bins=101)

    assert isinstance(built, DistributionalVFConfig)
    assert built.num_value_bins == 101
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_hl_gauss_sums_to_one():
    """Every HL-Gauss target row has NUM_BINS entries whose total mass is 1."""
    model = _make_model()
    values = torch.tensor([-0.5, -0.1, -0.9, -0.0])

    probs = model.hl_gauss_target(values)

    assert probs.shape == (4, NUM_BINS)
    row_mass = probs.sum(dim=-1)
    torch.testing.assert_close(row_mass, torch.ones(4), atol=1e-5, rtol=0)
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_hl_gauss_non_negative():
    """No HL-Gauss target probability is negative across the sampled value range."""
    model = _make_model()
    values = torch.linspace(-1.0, 0.0, 10)

    probs = model.hl_gauss_target(values)

    assert torch.all(probs >= 0)
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_hl_gauss_expected_value_matches():
    """The mean of the HL-Gauss target over the bin centers recovers the scalar target."""
    model = _make_model()
    values = torch.tensor([-0.5, -0.1, -0.9])

    probs = model.hl_gauss_target(values)
    # Expectation of the categorical distribution over the model's bin centers.
    mean_value = torch.sum(probs * model.bin_centers, dim=-1)

    torch.testing.assert_close(mean_value, values, atol=1e-4, rtol=0)
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_hl_gauss_handles_2d_input():
    """A column-shaped [batch_size, 1] target still yields one normalized row per sample."""
    model = _make_model()
    column_values = torch.tensor([[-0.5], [-0.3]])

    probs = model.hl_gauss_target(column_values)

    assert probs.shape == (2, NUM_BINS)
    row_mass = probs.sum(dim=-1)
    torch.testing.assert_close(row_mass, torch.ones(2), atol=1e-5, rtol=0)
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_dirac_delta_sums_to_one():
    """Every Dirac-delta target row has NUM_BINS entries whose total mass is 1."""
    model = _make_model()
    values = torch.tensor([-0.5, -0.1, -0.9, -1.0, 0.0])

    probs = model.dirac_delta_target(values)

    assert probs.shape == (5, NUM_BINS)
    row_mass = probs.sum(dim=-1)
    torch.testing.assert_close(row_mass, torch.ones(5), atol=1e-6, rtol=0)
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_dirac_delta_at_most_two_nonzero():
    """A Dirac-delta target row has no more than two strictly positive bins."""
    model = _make_model()
    values = torch.tensor([-0.7523, -0.0013])

    probs = model.dirac_delta_target(values)

    for row_idx in range(len(values)):
        positive_bins = (probs[row_idx] > 0).sum()
        assert positive_bins <= 2
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_dirac_delta_expected_value_matches():
    """The mean of the Dirac-delta target over the bin centers recovers the scalar target."""
    model = _make_model()
    values = torch.tensor([-0.5, -0.1, -0.9])

    probs = model.dirac_delta_target(values)
    # Expectation of the categorical distribution over the model's bin centers.
    mean_value = torch.sum(probs * model.bin_centers, dim=-1)

    torch.testing.assert_close(mean_value, values, atol=1e-5, rtol=0)
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_dirac_delta_boundary_values_clamped():
    """Targets below/above the support put all mass on the first/last bin respectively."""
    model = _make_model()
    out_of_range = torch.tensor([-1.5, 0.5])

    probs = model.dirac_delta_target(out_of_range)

    row_mass = probs.sum(dim=-1)
    torch.testing.assert_close(row_mass, torch.ones(2), atol=1e-6, rtol=0)
    # -1.5 collapses onto the first bin, 0.5 onto the last one.
    assert probs[0, 0] == 1.0
    assert probs[1, -1] == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_one_hot_single_nonzero():
    """Each one-hot target row has a single positive bin and sums to exactly 1."""
    model = _make_model()
    values = torch.tensor([-0.5, -0.1, -1.0, 0.0])

    probs = model.one_hot_target(values)

    assert probs.shape == (4, NUM_BINS)
    for row_idx in range(len(values)):
        row = probs[row_idx]
        assert (row > 0).sum() == 1
        assert row.sum() == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_one_hot_nearest_bin():
    """The active bin of a one-hot target has a center within 0.003 of the target value."""
    model = _make_model()
    target_value = -0.5

    probs = model.one_hot_target(torch.tensor([target_value]))

    active_bin = probs[0].argmax()
    active_center = model.bin_centers[active_bin].item()
    assert active_center == pytest.approx(target_value, abs=0.003)
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_terminal_gets_one_hot():
    """With the terminal override on, terminal rows are one-hot and the others stay spread out."""
    model = _make_model()
    values = torch.tensor([-0.5, -0.3, -0.7, -0.9])
    terminal_mask = torch.tensor([False, True, False, True])

    probs = model.compute_target_distribution(
        values, terminal_mask, method="hl_gauss", use_one_hot_terminal=True
    )

    # Terminal or not, every row must remain a normalized distribution.
    for row_idx in range(len(values)):
        assert probs[row_idx].sum().item() == pytest.approx(1.0, abs=1e-5)

    # Rows flagged terminal (1 and 3) collapse onto a single bin.
    for terminal_row in (1, 3):
        assert (probs[terminal_row] > 0).sum() == 1

    # Non-terminal rows (0 and 2) keep mass on more than two bins.
    for regular_row in (0, 2):
        assert (probs[regular_row] > 0).sum() > 2
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_no_terminal_override_when_disabled():
    """With use_one_hot_terminal=False a terminal row still spreads mass over more than two bins."""
    model = _make_model()
    values = torch.tensor([-0.5, -0.3])
    terminal_mask = torch.tensor([False, True])

    probs = model.compute_target_distribution(
        values, terminal_mask, method="hl_gauss", use_one_hot_terminal=False
    )

    # Row 1 is terminal; it must not have been replaced by a one-hot target.
    terminal_support = (probs[1] > 0).sum()
    assert terminal_support > 2
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_model_has_expected_components():
    """The constructed model exposes every attribute the architecture is expected to have."""
    model = _make_model()

    expected_attributes = (
        "vision_tower",
        "multi_modal_projector",
        "token_embedding",
        "layers",
        "value_head",
        "cls_embedding",
        "norm",
        "rotary_emb",
        "bin_centers",
    )
    for attribute in expected_attributes:
        assert hasattr(model, attribute)
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_model_bin_centers_shape():
    """``bin_centers`` is one-dimensional with NUM_BINS entries."""
    bin_centers = _make_model().bin_centers
    assert bin_centers.shape == (NUM_BINS,)
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_model_layer_count():
    """The model holds exactly six transformer layers.

    The value 6 is presumably the ``num_hidden_layers`` of the test config,
    which is defined outside this function.
    """
    layer_stack = _make_model().layers
    assert len(layer_stack) == 6
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_model_value_head_output_dim():
    """The value head emits one logit per value bin (``out_features == NUM_BINS``)."""
    value_head = _make_model().value_head
    assert value_head.out_features == NUM_BINS
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_forward_returns_loss_and_dict():
    """``forward`` yields a finite 0-d loss plus a dict carrying the expected metric keys."""
    model = _make_model()
    batch = _make_batch()

    loss, metrics = model.forward(batch)

    # The loss must be a scalar (0-d tensor) and free of NaN/inf.
    assert loss.shape == ()
    assert torch.isfinite(loss)

    for metric_name in ("loss", "predicted_value_mean", "mc_return_mean"):
        assert metric_name in metrics
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_forward_loss_is_positive():
    """The training loss of a freshly built model is strictly greater than zero."""
    model = _make_model()
    batch = _make_batch()

    loss = model.forward(batch)[0]

    assert loss.item() > 0
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_compute_reward_returns_correct_shape():
    """``compute_reward`` returns one finite float32 value per sample in the batch."""
    num_samples = 3
    model = _make_model()
    model.eval()
    batch = _make_batch(batch_size=num_samples)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        rewards = model.compute_reward(batch)

    assert rewards.shape == (num_samples,)
    assert rewards.dtype == torch.float32
    assert torch.isfinite(rewards).all()
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_compute_reward_values_in_support_range():
    """Predicted values stay inside [-1, 0], allowing a 0.01 slack on both ends."""
    slack = 0.01
    lower_bound = -1.0 - slack
    upper_bound = 0.0 + slack

    model = _make_model()
    model.eval()
    batch = _make_batch(batch_size=8)

    with torch.no_grad():
        rewards = model.compute_reward(batch)

    assert (rewards >= lower_bound).all()
    assert (rewards <= upper_bound).all()
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_processor_pipeline_produces_expected_keys():
    """The preprocessor output contains language tokens, their attention mask and the image."""
    from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
        make_distributional_vf_pre_post_processors,
    )
    from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS

    cfg = _make_config()
    # Only the preprocessor is exercised; the postprocessor is discarded.
    preprocessor = make_distributional_vf_pre_post_processors(cfg)[0]

    unprocessed = {
        IMAGE_KEY: torch.rand(3, 224, 224),
        "task": "pick up the cup",
    }
    processed = preprocessor(unprocessed)

    for expected_key in (OBS_LANGUAGE_TOKENS, OBS_LANGUAGE_ATTENTION_MASK, IMAGE_KEY):
        assert expected_key in processed
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_gradient_flows_through_value_head():
    """After backward, the value-head weight has a gradient that is not identically zero."""
    model = _make_model()
    model.train()
    batch = _make_batch()

    loss = model.forward(batch)[0]
    loss.backward()

    head_grad = model.value_head.weight.grad
    assert head_grad is not None
    assert not torch.all(head_grad == 0)
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_gradient_flows_through_cls_embedding():
    """After backward, the ``cls_embedding`` parameter has a gradient that is not identically zero."""
    model = _make_model()
    model.train()
    batch = _make_batch()

    loss = model.forward(batch)[0]
    loss.backward()

    cls_grad = model.cls_embedding.grad
    assert cls_grad is not None
    assert not torch.all(cls_grad == 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_requires_visual_feature():
    """``validate_features`` raises a ValueError mentioning VISUAL when only a state input exists."""
    state_only_inputs = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
    }
    cfg = DistributionalVFConfig(init_from_actor_path="")
    cfg.input_features = state_only_inputs

    with pytest.raises(ValueError, match="VISUAL"):
        cfg.validate_features()
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_passes_with_visual_feature():
    """``validate_features`` does not raise for the config built by ``_make_config``."""
    # No explicit assertion: the test passes as long as validation does not raise.
    _make_config().validate_features()
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_save_load_pretrained_roundtrip(tmp_path):
    """Saving a model and loading it back yields the same state-dict keys and tensor values."""
    from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
        DistributionalVFRewardModel,
    )

    model = _make_model()
    model._save_pretrained(tmp_path)
    reloaded = DistributionalVFRewardModel.from_pretrained(str(tmp_path))

    saved_state = model.state_dict()
    reloaded_state = reloaded.state_dict()

    # Same set of entries first, then the same value for each entry.
    assert set(saved_state) == set(reloaded_state)
    for key, saved_tensor in saved_state.items():
        torch.testing.assert_close(saved_tensor, reloaded_state[key], msg=f"Mismatch in {key}")
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_image_preprocessor_normalizes_to_minus_one_one():
    """Image preprocessor scales [0, 1] float input to [-1, 1] for SigLIP.

    The random input already lies in [0, 1), so checking only the [-1, 1]
    bounds would also pass for a preprocessor that leaves the image untouched.
    The test therefore additionally requires negative values in the output,
    which can only appear if the rescaling actually happened.
    """
    from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
        DistributionalVFImagePreprocessorStep,
    )

    step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))

    transition = {
        TransitionKey.OBSERVATION: {
            IMAGE_KEY: torch.rand(1, 224, 224, 3),
        },
    }

    result = step(transition)
    image = result[TransitionKey.OBSERVATION][IMAGE_KEY]

    # Output must stay inside the target range (small float tolerance).
    assert image.min() >= -1.0 - 1e-5
    assert image.max() <= 1.0 + 1e-5
    # A non-negative input only produces negative pixels if it was rescaled.
    assert image.min() < 0.0
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_package_missing("transformers")
def test_image_preprocessor_resizes_with_pad():
    """A 480x640 input comes out of the image step with dimensions 1 and 2 equal to (224, 224)."""
    from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
        DistributionalVFImagePreprocessorStep,
    )

    target_resolution = (224, 224)
    step = DistributionalVFImagePreprocessorStep(image_resolution=target_resolution, image_keys=(IMAGE_KEY,))

    non_square_image = torch.rand(1, 480, 640, 3)
    transition = {TransitionKey.OBSERVATION: {IMAGE_KEY: non_square_image}}

    resized = step(transition)[TransitionKey.OBSERVATION][IMAGE_KEY]

    assert resized.shape[1:3] == target_resolution
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_prompt_formats_correctly():
    """An underscored task in a list becomes the prompt 'Task: pick up the cup.'."""
    from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
        DistributionalVFPrepareTaskPromptStep,
    )

    step = DistributionalVFPrepareTaskPromptStep()
    transition = {TransitionKey.COMPLEMENTARY_DATA: {"task": ["pick_up_the_cup"]}}

    output = step(transition)

    first_prompt = output[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
    assert first_prompt == "Task: pick up the cup."
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_prompt_handles_string_input():
    """A bare string task is accepted and its first output prompt is 'Task: open drawer.'."""
    from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
        DistributionalVFPrepareTaskPromptStep,
    )

    step = DistributionalVFPrepareTaskPromptStep()
    transition = {TransitionKey.COMPLEMENTARY_DATA: {"task": "open_drawer"}}

    output = step(transition)

    first_prompt = output[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
    assert first_prompt == "Task: open drawer."
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_prompt_raises_on_missing_task():
    """Empty complementary data makes the prompt step raise ValueError matching 'No task found'."""
    from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
        DistributionalVFPrepareTaskPromptStep,
    )

    step = DistributionalVFPrepareTaskPromptStep()
    transition_without_task = {TransitionKey.COMPLEMENTARY_DATA: {}}

    with pytest.raises(ValueError, match="No task found"):
        step(transition_without_task)
|
||||||
@@ -59,6 +59,7 @@ def test_strategy_config_types():
|
|||||||
from lerobot.rollout import (
|
from lerobot.rollout import (
|
||||||
BaseStrategyConfig,
|
BaseStrategyConfig,
|
||||||
DAggerStrategyConfig,
|
DAggerStrategyConfig,
|
||||||
|
EpisodicStrategyConfig,
|
||||||
HighlightStrategyConfig,
|
HighlightStrategyConfig,
|
||||||
SentryStrategyConfig,
|
SentryStrategyConfig,
|
||||||
)
|
)
|
||||||
@@ -67,6 +68,7 @@ def test_strategy_config_types():
|
|||||||
assert SentryStrategyConfig().type == "sentry"
|
assert SentryStrategyConfig().type == "sentry"
|
||||||
assert HighlightStrategyConfig().type == "highlight"
|
assert HighlightStrategyConfig().type == "highlight"
|
||||||
assert DAggerStrategyConfig().type == "dagger"
|
assert DAggerStrategyConfig().type == "dagger"
|
||||||
|
assert EpisodicStrategyConfig().type == "episodic"
|
||||||
|
|
||||||
|
|
||||||
def test_dagger_config_invalid_input_device():
|
def test_dagger_config_invalid_input_device():
|
||||||
@@ -203,6 +205,8 @@ def test_create_strategy_dispatches():
|
|||||||
BaseStrategyConfig,
|
BaseStrategyConfig,
|
||||||
DAggerStrategy,
|
DAggerStrategy,
|
||||||
DAggerStrategyConfig,
|
DAggerStrategyConfig,
|
||||||
|
EpisodicStrategy,
|
||||||
|
EpisodicStrategyConfig,
|
||||||
SentryStrategy,
|
SentryStrategy,
|
||||||
SentryStrategyConfig,
|
SentryStrategyConfig,
|
||||||
create_strategy,
|
create_strategy,
|
||||||
@@ -211,6 +215,7 @@ def test_create_strategy_dispatches():
|
|||||||
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
||||||
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
||||||
assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy)
|
assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy)
|
||||||
|
assert isinstance(create_strategy(EpisodicStrategyConfig()), EpisodicStrategy)
|
||||||
|
|
||||||
|
|
||||||
def test_create_strategy_unknown_raises():
|
def test_create_strategy_unknown_raises():
|
||||||
|
|||||||
Reference in New Issue
Block a user