mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 22:49:48 +00:00
Compare commits
9 Commits
e7886c2285
...
1873668521
| Author | SHA1 | Date | |
|---|---|---|---|
| 1873668521 | |||
| 28ee32f721 | |||
| d28316d1b6 | |||
| 9e83510c99 | |||
| 1f7b03f5f2 | |||
| cb8edf17e6 | |||
| 5699f6cbf4 | |||
| 0e6114ac36 | |||
| c8ce413d73 |
@@ -19,8 +19,8 @@ on:
|
|||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
# Runs at 02:00
|
# Runs at 02:00
|
||||||
schedule:
|
# schedule:
|
||||||
- cron: "0 2 * * *"
|
# - cron: "0 2 * * *"
|
||||||
|
|
||||||
env:
|
env:
|
||||||
CLOSE_ISSUE_MESSAGE: >
|
CLOSE_ISSUE_MESSAGE: >
|
||||||
|
|||||||
+4
-4
@@ -59,8 +59,8 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
|||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
# Core ML
|
# Core ML
|
||||||
"torch>=2.7,<2.11.0",
|
"torch>=2.7,<2.13.0",
|
||||||
"torchvision>=0.22.0,<0.26.0",
|
"torchvision>=0.22.0,<0.28.0",
|
||||||
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
||||||
"opencv-python-headless>=4.9.0,<4.14.0",
|
"opencv-python-headless>=4.9.0,<4.14.0",
|
||||||
"Pillow>=10.0.0,<13.0.0",
|
"Pillow>=10.0.0,<13.0.0",
|
||||||
@@ -99,7 +99,7 @@ dataset = [
|
|||||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||||
"lerobot[av-dep]",
|
"lerobot[av-dep]",
|
||||||
"torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
|
"torchcodec>=0.3.0,<0.13.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10), 0.11 needs torch==2.11, 0.12 needs torch==2.12.
|
||||||
"jsonlines>=4.0.0,<5.0.0",
|
"jsonlines>=4.0.0,<5.0.0",
|
||||||
]
|
]
|
||||||
training = [
|
training = [
|
||||||
@@ -195,7 +195,7 @@ groot = [
|
|||||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||||
xvla = ["lerobot[transformers-dep]"]
|
xvla = ["lerobot[transformers-dep]"]
|
||||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
|
|
||||||
# Features
|
# Features
|
||||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||||
|
|||||||
@@ -256,7 +256,9 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
cli_args = kwargs.pop("cli_args", [])
|
cli_args = kwargs.pop("cli_args", [])
|
||||||
if config_file is not None:
|
# Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
|
||||||
|
# Hand-written YAML/TOML configs are expected to use the current sample_weighting schema.
|
||||||
|
if config_file is not None and config_file.endswith(".json"):
|
||||||
with open(config_file) as f:
|
with open(config_file) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
migrated_config = _migrate_legacy_rabc_fields(config)
|
migrated_config = _migrate_legacy_rabc_fields(config)
|
||||||
|
|||||||
@@ -282,7 +282,11 @@ class VideoDecoderCache:
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
if video_path not in self._cache:
|
if video_path not in self._cache:
|
||||||
file_handle = fsspec.open(video_path).__enter__()
|
file_handle = fsspec.open(video_path).__enter__()
|
||||||
|
try:
|
||||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||||
|
except Exception:
|
||||||
|
file_handle.close()
|
||||||
|
raise
|
||||||
self._cache[video_path] = (decoder, file_handle)
|
self._cache[video_path] = (decoder, file_handle)
|
||||||
|
|
||||||
return self._cache[video_path][0]
|
return self._cache[video_path][0]
|
||||||
|
|||||||
@@ -939,7 +939,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
|
|||||||
input_dtype = query_states.dtype
|
input_dtype = query_states.dtype
|
||||||
if input_dtype == torch.float32:
|
if input_dtype == torch.float32:
|
||||||
if torch.is_autocast_enabled():
|
if torch.is_autocast_enabled():
|
||||||
target_dtype = torch.get_autocast_gpu_dtype()
|
target_dtype = torch.get_autocast_dtype(query_states.device.type)
|
||||||
# Handle the case where the model is quantized
|
# Handle the case where the model is quantized
|
||||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
target_dtype = self.config._pre_quantization_dtype
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
|||||||
@@ -985,7 +985,7 @@ class Florence2FlashAttention2(Florence2Attention):
|
|||||||
input_dtype = query_states.dtype
|
input_dtype = query_states.dtype
|
||||||
if input_dtype == torch.float32:
|
if input_dtype == torch.float32:
|
||||||
if torch.is_autocast_enabled():
|
if torch.is_autocast_enabled():
|
||||||
target_dtype = torch.get_autocast_gpu_dtype()
|
target_dtype = torch.get_autocast_dtype(query_states.device.type)
|
||||||
# Handle the case where the model is quantized
|
# Handle the case where the model is quantized
|
||||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
target_dtype = self.config._pre_quantization_dtype
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
|||||||
+51
-39
@@ -52,27 +52,15 @@ import time
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from queue import Empty
|
from queue import Empty
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from lerobot.utils.import_utils import require_package
|
from lerobot.utils.import_utils import _grpc_available, require_package
|
||||||
|
|
||||||
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
|
if TYPE_CHECKING or _grpc_available:
|
||||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
import grpc
|
||||||
|
|
||||||
import grpc # noqa: E402
|
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||||
import torch # noqa: E402
|
from lerobot.transport.utils import (
|
||||||
from torch import nn # noqa: E402
|
|
||||||
from torch.multiprocessing import Queue # noqa: E402
|
|
||||||
|
|
||||||
from lerobot.cameras import opencv # noqa: F401, E402
|
|
||||||
from lerobot.configs import parser # noqa: E402
|
|
||||||
from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402
|
|
||||||
from lerobot.processor import TransitionKey # noqa: E402
|
|
||||||
from lerobot.robots import so_follower # noqa: F401, E402
|
|
||||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402
|
|
||||||
from lerobot.teleoperators.utils import TeleopEvents # noqa: E402
|
|
||||||
from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402
|
|
||||||
from lerobot.transport.utils import ( # noqa: E402
|
|
||||||
bytes_to_state_dict,
|
bytes_to_state_dict,
|
||||||
grpc_channel_options,
|
grpc_channel_options,
|
||||||
python_object_to_bytes,
|
python_object_to_bytes,
|
||||||
@@ -80,35 +68,59 @@ from lerobot.transport.utils import ( # noqa: E402
|
|||||||
send_bytes_in_chunks,
|
send_bytes_in_chunks,
|
||||||
transitions_to_bytes,
|
transitions_to_bytes,
|
||||||
)
|
)
|
||||||
from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402
|
else:
|
||||||
from lerobot.utils.process import ProcessSignalHandler # noqa: E402
|
grpc = None
|
||||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
services_pb2 = None
|
||||||
from lerobot.utils.robot_utils import precise_sleep # noqa: E402
|
services_pb2_grpc = None
|
||||||
from lerobot.utils.transition import ( # noqa: E402
|
bytes_to_state_dict = None
|
||||||
|
grpc_channel_options = None
|
||||||
|
python_object_to_bytes = None
|
||||||
|
receive_bytes_in_chunks = None
|
||||||
|
send_bytes_in_chunks = None
|
||||||
|
transitions_to_bytes = None
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.multiprocessing import Queue
|
||||||
|
|
||||||
|
from lerobot.cameras import opencv # noqa: F401
|
||||||
|
from lerobot.configs import parser
|
||||||
|
from lerobot.policies import make_policy, make_pre_post_processors
|
||||||
|
from lerobot.processor import TransitionKey
|
||||||
|
from lerobot.robots import so_follower # noqa: F401
|
||||||
|
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||||
|
from lerobot.teleoperators.utils import TeleopEvents
|
||||||
|
from lerobot.utils.device_utils import get_safe_torch_device
|
||||||
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
|
from lerobot.utils.random_utils import set_seed
|
||||||
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
|
from lerobot.utils.transition import (
|
||||||
Transition,
|
Transition,
|
||||||
move_transition_to_device,
|
move_transition_to_device,
|
||||||
)
|
)
|
||||||
from lerobot.utils.utils import ( # noqa: E402
|
from lerobot.utils.utils import (
|
||||||
TimerManager,
|
TimerManager,
|
||||||
init_logging,
|
init_logging,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .algorithms.base import RLAlgorithm # noqa: E402
|
from .algorithms.base import RLAlgorithm
|
||||||
from .algorithms.factory import make_algorithm # noqa: E402
|
from .algorithms.factory import make_algorithm
|
||||||
from .gym_manipulator import ( # noqa: E402
|
from .gym_manipulator import (
|
||||||
make_processors,
|
make_processors,
|
||||||
make_robot_env,
|
make_robot_env,
|
||||||
reset_and_build_transition,
|
reset_and_build_transition,
|
||||||
step_env_and_process_transition,
|
step_env_and_process_transition,
|
||||||
)
|
)
|
||||||
from .queue import get_last_item_from_queue # noqa: E402
|
from .queue import get_last_item_from_queue
|
||||||
from .train_rl import TrainRLServerPipelineConfig # noqa: E402
|
from .train_rl import TrainRLServerPipelineConfig
|
||||||
|
|
||||||
# Main entry point
|
# Main entry point
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def actor_cli(cfg: TrainRLServerPipelineConfig):
|
def actor_cli(cfg: TrainRLServerPipelineConfig):
|
||||||
|
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
|
||||||
|
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||||
cfg.validate()
|
cfg.validate()
|
||||||
display_pid = False
|
display_pid = False
|
||||||
if not use_threads(cfg):
|
if not use_threads(cfg):
|
||||||
@@ -421,7 +433,7 @@ def act_with_policy(
|
|||||||
|
|
||||||
|
|
||||||
def establish_learner_connection(
|
def establish_learner_connection(
|
||||||
stub: services_pb2_grpc.LearnerServiceStub,
|
stub: "services_pb2_grpc.LearnerServiceStub",
|
||||||
shutdown_event: Any, # Event
|
shutdown_event: Any, # Event
|
||||||
attempts: int = 30,
|
attempts: int = 30,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@@ -454,7 +466,7 @@ def establish_learner_connection(
|
|||||||
def learner_service_client(
|
def learner_service_client(
|
||||||
host: str = "127.0.0.1",
|
host: str = "127.0.0.1",
|
||||||
port: int = 50051,
|
port: int = 50051,
|
||||||
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]":
|
||||||
"""Return a client for the learner service.
|
"""Return a client for the learner service.
|
||||||
|
|
||||||
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
|
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
|
||||||
@@ -477,8 +489,8 @@ def receive_policy(
|
|||||||
cfg: TrainRLServerPipelineConfig,
|
cfg: TrainRLServerPipelineConfig,
|
||||||
parameters_queue: Queue,
|
parameters_queue: Queue,
|
||||||
shutdown_event: Any, # Event
|
shutdown_event: Any, # Event
|
||||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
|
||||||
grpc_channel: grpc.Channel | None = None,
|
grpc_channel: "grpc.Channel | None" = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Receive parameters from the learner.
|
"""Receive parameters from the learner.
|
||||||
|
|
||||||
@@ -531,8 +543,8 @@ def send_transitions(
|
|||||||
cfg: TrainRLServerPipelineConfig,
|
cfg: TrainRLServerPipelineConfig,
|
||||||
transitions_queue: Queue,
|
transitions_queue: Queue,
|
||||||
shutdown_event: Any, # Event
|
shutdown_event: Any, # Event
|
||||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
|
||||||
grpc_channel: grpc.Channel | None = None,
|
grpc_channel: "grpc.Channel | None" = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Send transitions to the learner.
|
"""Send transitions to the learner.
|
||||||
|
|
||||||
@@ -587,8 +599,8 @@ def send_interactions(
|
|||||||
cfg: TrainRLServerPipelineConfig,
|
cfg: TrainRLServerPipelineConfig,
|
||||||
interactions_queue: Queue,
|
interactions_queue: Queue,
|
||||||
shutdown_event: Any, # Event
|
shutdown_event: Any, # Event
|
||||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
|
||||||
grpc_channel: grpc.Channel | None = None,
|
grpc_channel: "grpc.Channel | None" = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Send interactions to the learner.
|
"""Send interactions to the learner.
|
||||||
|
|
||||||
@@ -646,7 +658,7 @@ def transitions_stream(
|
|||||||
shutdown_event: Any, # Event
|
shutdown_event: Any, # Event
|
||||||
transitions_queue: Queue,
|
transitions_queue: Queue,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
) -> Generator[Any, None, services_pb2.Empty]:
|
) -> "Generator[Any, None, services_pb2.Empty]":
|
||||||
while not shutdown_event.is_set():
|
while not shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
message = transitions_queue.get(block=True, timeout=timeout)
|
message = transitions_queue.get(block=True, timeout=timeout)
|
||||||
@@ -665,7 +677,7 @@ def interactions_stream(
|
|||||||
shutdown_event: Any, # Event
|
shutdown_event: Any, # Event
|
||||||
interactions_queue: Queue,
|
interactions_queue: Queue,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
) -> Generator[Any, None, services_pb2.Empty]:
|
) -> "Generator[Any, None, services_pb2.Empty]":
|
||||||
while not shutdown_event.is_set():
|
while not shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
message = interactions_queue.get(block=True, timeout=timeout)
|
message = interactions_queue.get(block=True, timeout=timeout)
|
||||||
|
|||||||
+41
-36
@@ -51,44 +51,47 @@ import time
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from lerobot.utils.import_utils import require_package
|
from lerobot.utils.import_utils import _grpc_available, require_package
|
||||||
|
|
||||||
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
|
if TYPE_CHECKING or _grpc_available:
|
||||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
import grpc
|
||||||
|
|
||||||
import grpc # noqa: E402
|
from lerobot.transport import services_pb2_grpc
|
||||||
import torch # noqa: E402
|
else:
|
||||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE # noqa: E402
|
grpc = None
|
||||||
from safetensors.torch import load_file as load_safetensors # noqa: E402
|
services_pb2_grpc = None
|
||||||
from termcolor import colored # noqa: E402
|
|
||||||
from torch import nn # noqa: E402
|
|
||||||
from torch.multiprocessing import Queue # noqa: E402
|
|
||||||
from torch.optim.optimizer import Optimizer # noqa: E402
|
|
||||||
|
|
||||||
from lerobot.cameras import opencv # noqa: F401, E402
|
import torch
|
||||||
from lerobot.common.train_utils import ( # noqa: E402
|
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||||
|
from safetensors.torch import load_file as load_safetensors
|
||||||
|
from termcolor import colored
|
||||||
|
from torch import nn
|
||||||
|
from torch.multiprocessing import Queue
|
||||||
|
from torch.optim.optimizer import Optimizer
|
||||||
|
|
||||||
|
from lerobot.cameras import opencv # noqa: F401
|
||||||
|
from lerobot.common.train_utils import (
|
||||||
get_step_checkpoint_dir,
|
get_step_checkpoint_dir,
|
||||||
load_training_state as utils_load_training_state,
|
load_training_state as utils_load_training_state,
|
||||||
save_checkpoint,
|
save_checkpoint,
|
||||||
update_last_checkpoint,
|
update_last_checkpoint,
|
||||||
)
|
)
|
||||||
from lerobot.common.wandb_utils import WandBLogger # noqa: E402
|
from lerobot.common.wandb_utils import WandBLogger
|
||||||
from lerobot.configs import parser # noqa: E402
|
from lerobot.configs import parser
|
||||||
from lerobot.datasets import LeRobotDataset, make_dataset # noqa: E402
|
from lerobot.datasets import LeRobotDataset, make_dataset
|
||||||
from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402
|
from lerobot.policies import make_policy, make_pre_post_processors
|
||||||
from lerobot.robots import so_follower # noqa: F401, E402
|
from lerobot.robots import so_follower # noqa: F401
|
||||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402
|
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||||
from lerobot.teleoperators.utils import TeleopEvents # noqa: E402
|
from lerobot.teleoperators.utils import TeleopEvents
|
||||||
from lerobot.transport import services_pb2_grpc # noqa: E402
|
from lerobot.transport.utils import (
|
||||||
from lerobot.transport.utils import ( # noqa: E402
|
|
||||||
MAX_MESSAGE_SIZE,
|
MAX_MESSAGE_SIZE,
|
||||||
bytes_to_python_object,
|
bytes_to_python_object,
|
||||||
bytes_to_transitions,
|
bytes_to_transitions,
|
||||||
state_to_bytes,
|
state_to_bytes,
|
||||||
)
|
)
|
||||||
from lerobot.utils.constants import ( # noqa: E402
|
from lerobot.utils.constants import (
|
||||||
ACTION,
|
ACTION,
|
||||||
ALGORITHM_DIR,
|
ALGORITHM_DIR,
|
||||||
CHECKPOINTS_DIR,
|
CHECKPOINTS_DIR,
|
||||||
@@ -97,26 +100,28 @@ from lerobot.utils.constants import ( # noqa: E402
|
|||||||
TRAINING_STATE_DIR,
|
TRAINING_STATE_DIR,
|
||||||
TRAINING_STEP,
|
TRAINING_STEP,
|
||||||
)
|
)
|
||||||
from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402
|
from lerobot.utils.device_utils import get_safe_torch_device
|
||||||
from lerobot.utils.io_utils import load_json, write_json # noqa: E402
|
from lerobot.utils.io_utils import load_json, write_json
|
||||||
from lerobot.utils.process import ProcessSignalHandler # noqa: E402
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
from lerobot.utils.random_utils import set_seed
|
||||||
from lerobot.utils.utils import ( # noqa: E402
|
from lerobot.utils.utils import (
|
||||||
format_big_number,
|
format_big_number,
|
||||||
init_logging,
|
init_logging,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .algorithms.base import RLAlgorithm # noqa: E402
|
from .algorithms.base import RLAlgorithm
|
||||||
from .algorithms.factory import make_algorithm # noqa: E402
|
from .algorithms.factory import make_algorithm
|
||||||
from .buffer import ReplayBuffer # noqa: E402
|
from .buffer import ReplayBuffer
|
||||||
from .data_sources import OnlineOfflineMixer # noqa: E402
|
from .data_sources import OnlineOfflineMixer
|
||||||
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService # noqa: E402
|
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
|
||||||
from .train_rl import TrainRLServerPipelineConfig # noqa: E402
|
from .train_rl import TrainRLServerPipelineConfig
|
||||||
from .trainer import RLTrainer # noqa: E402
|
from .trainer import RLTrainer
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def train_cli(cfg: TrainRLServerPipelineConfig):
|
def train_cli(cfg: TrainRLServerPipelineConfig):
|
||||||
|
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
|
||||||
|
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||||
if not use_threads(cfg):
|
if not use_threads(cfg):
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
|||||||
@@ -18,22 +18,32 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from multiprocessing import Event, Queue
|
from multiprocessing import Event, Queue
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from lerobot.utils.import_utils import require_package
|
from lerobot.utils.import_utils import _grpc_available
|
||||||
|
|
||||||
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
|
from .queue import get_last_item_from_queue
|
||||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
|
||||||
|
|
||||||
from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402
|
if TYPE_CHECKING or _grpc_available:
|
||||||
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks # noqa: E402
|
import grpc
|
||||||
|
|
||||||
from .queue import get_last_item_from_queue # noqa: E402
|
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||||
|
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
||||||
|
|
||||||
|
_ServicerBase = services_pb2_grpc.LearnerServiceServicer
|
||||||
|
else:
|
||||||
|
grpc = None
|
||||||
|
services_pb2 = None
|
||||||
|
services_pb2_grpc = None
|
||||||
|
receive_bytes_in_chunks = None
|
||||||
|
send_bytes_in_chunks = None
|
||||||
|
_ServicerBase = object
|
||||||
|
|
||||||
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
||||||
SHUTDOWN_TIMEOUT = 10
|
SHUTDOWN_TIMEOUT = 10
|
||||||
|
|
||||||
|
|
||||||
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
class LearnerService(_ServicerBase):
|
||||||
"""
|
"""
|
||||||
Implementation of the LearnerService gRPC service
|
Implementation of the LearnerService gRPC service
|
||||||
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
|
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
|
||||||
@@ -56,7 +66,9 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
|||||||
self.interaction_message_queue = interaction_message_queue
|
self.interaction_message_queue = interaction_message_queue
|
||||||
self.queue_get_timeout = queue_get_timeout
|
self.queue_get_timeout = queue_get_timeout
|
||||||
|
|
||||||
def StreamParameters(self, request, context): # noqa: N802
|
def StreamParameters( # noqa: N802
|
||||||
|
self, request: "services_pb2.Empty", context: "grpc.ServicerContext"
|
||||||
|
):
|
||||||
# TODO: authorize the request
|
# TODO: authorize the request
|
||||||
logging.info("[LEARNER] Received request to stream parameters from the Actor")
|
logging.info("[LEARNER] Received request to stream parameters from the Actor")
|
||||||
|
|
||||||
@@ -91,7 +103,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
|||||||
logging.info("[LEARNER] Stream parameters finished")
|
logging.info("[LEARNER] Stream parameters finished")
|
||||||
return services_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|
||||||
def SendTransitions(self, request_iterator, _context): # noqa: N802
|
def SendTransitions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
|
||||||
# TODO: authorize the request
|
# TODO: authorize the request
|
||||||
logging.info("[LEARNER] Received request to receive transitions from the Actor")
|
logging.info("[LEARNER] Received request to receive transitions from the Actor")
|
||||||
|
|
||||||
@@ -105,7 +117,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
|||||||
logging.debug("[LEARNER] Finished receiving transitions")
|
logging.debug("[LEARNER] Finished receiving transitions")
|
||||||
return services_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|
||||||
def SendInteractions(self, request_iterator, _context): # noqa: N802
|
def SendInteractions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
|
||||||
# TODO: authorize the request
|
# TODO: authorize the request
|
||||||
logging.info("[LEARNER] Received request to receive interactions from the Actor")
|
logging.info("[LEARNER] Received request to receive interactions from the Actor")
|
||||||
|
|
||||||
@@ -119,5 +131,5 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
|||||||
logging.debug("[LEARNER] Finished receiving interactions")
|
logging.debug("[LEARNER] Finished receiving interactions")
|
||||||
return services_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|
||||||
def Ready(self, request, context): # noqa: N802
|
def Ready(self, request: "services_pb2.Empty", context: "grpc.ServicerContext"): # noqa: N802
|
||||||
return services_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class LeKiwiConfig(RobotConfig):
|
|||||||
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
|
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
|
||||||
|
|
||||||
# Set to `True` for backward compatibility with previous policies/dataset
|
# Set to `True` for backward compatibility with previous policies/dataset
|
||||||
use_degrees: bool = False
|
use_degrees: bool = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
Reference in New Issue
Block a user