Compare commits

...

9 Commits

Author SHA1 Message Date
Khalil Meftah 1873668521 Merge remote-tracking branch 'origin/main' into user/khalil-meftah/2026-02-16-rl-stack-refactor
# Conflicts:
#	uv.lock
2026-05-11 15:06:16 +02:00
Khalil Meftah 28ee32f721 update uv.lock 2026-05-11 14:52:31 +02:00
Steven Palma d28316d1b6 chore(rl): manage import pattern in actor (#3564)
* chore(rl): manage import pattern in actor

* chore(rl): optional grpc imports in learner; quote grpc ServicerContext types

---------

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
2026-05-11 14:05:31 +02:00
Jash Shah 9e83510c99 fix(datasets): close file handle on VideoDecoder init failure in cache (#3542)
If VideoDecoder() raises during initialization, the fsspec file handle
was leaked since it was opened via __enter__() but never closed on the
exception path. Now explicitly closes the handle before re-raising.
2026-05-10 17:30:37 +02:00
Anthony Shoumikhin 1f7b03f5f2 chore(deps): allow torch 2.11/2.12 and fix autocast deprecation (#3435)
* chore(deps): allow torch 2.11/2.12 and fix autocast deprecation

- Bump torch to >=2.7,<2.13 (was <2.11), torchvision to <0.28 (was <0.26),
  and torchcodec to <0.13 (was <0.11) to allow installs against the latest
  stable torch 2.11 and the upcoming 2.12 line.
- Replace removed torch.get_autocast_gpu_dtype() with torch.get_autocast_dtype("cuda")
  in Florence2 and Qwen2.5-VL-MoE FlashAttention paths (the former is removed in 2.11+).
- Refresh uv.lock for the new resolution (torch 2.11.0+cu130, torchvision 0.26.0+cu130,
  torchcodec 0.11.1, full CUDA 13 stack).

Verified locally with `uv sync --locked` from a clean .venv and the lerobot
test suite (pytest -n 8 --dist=loadfile --timeout=300). Failure set is
identical to the pre-bump baseline: 18 pre-existing failures
(test_sac_policy*, test_pi0_rtc*, test_pi05_rtc*, test_replay_buffer*),
0 new, 0 fixed.

AI assistance: this change was authored with Claude Code per AI_POLICY.md.

* fix(policies): use device-agnostic autocast dtype lookup

Pass query_states.device.type to torch.get_autocast_dtype() instead of
hardcoding 'cuda', so the cast matches the active autocast context when
running under CPU/MPS/XPU autocast.

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-10 13:05:35 +02:00
Steven Palma cb8edf17e6 chore(dependencies): update uv.lock (#3475) 2026-05-10 12:24:22 +02:00
Steven Palma 5699f6cbf4 chore(ci): disable auto-stale (#3550) 2026-05-10 11:49:31 +02:00
masato-ka 0e6114ac36 fix(train): restrict legacy RA-BC migration to JSON checkpoints only (#3490)
* fix(train): restrict legacy RA-BC migration to JSON checkpoints only

_migrate_legacy_rabc_fields was called for all config files, causing
json.load to raise DecodeError when a YAML/TOML config was passed to
lerobot-train for a new training run. Guard the block with an
.endswith(".json") check so migration only runs when resuming from
a JSON checkpoint.
2026-05-08 20:27:01 +02:00
Steven Palma c8ce413d73 fix(robots): allign lekiwi default with so100 use_degrees (#3531) 2026-05-07 17:52:34 +02:00
11 changed files with 782 additions and 640 deletions
+2 -2
View File
@@ -19,8 +19,8 @@ on:
workflow_dispatch: workflow_dispatch:
# Runs at 02:00 # Runs at 02:00
schedule: # schedule:
- cron: "0 2 * * *" # - cron: "0 2 * * *"
env: env:
CLOSE_ISSUE_MESSAGE: > CLOSE_ISSUE_MESSAGE: >
+4 -4
View File
@@ -59,8 +59,8 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
dependencies = [ dependencies = [
# Core ML # Core ML
"torch>=2.7,<2.11.0", "torch>=2.7,<2.13.0",
"torchvision>=0.22.0,<0.26.0", "torchvision>=0.22.0,<0.28.0",
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless. "numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
"opencv-python-headless>=4.9.0,<4.14.0", "opencv-python-headless>=4.9.0,<4.14.0",
"Pillow>=10.0.0,<13.0.0", "Pillow>=10.0.0,<13.0.0",
@@ -99,7 +99,7 @@ dataset = [
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets "pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets "pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
"lerobot[av-dep]", "lerobot[av-dep]",
"torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10). "torchcodec>=0.3.0,<0.13.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10), 0.11 needs torch==2.11, 0.12 needs torch==2.12.
"jsonlines>=4.0.0,<5.0.0", "jsonlines>=4.0.0,<5.0.0",
] ]
training = [ training = [
@@ -195,7 +195,7 @@ groot = [
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"] sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
xvla = ["lerobot[transformers-dep]"] xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"] eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features # Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"] async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
+3 -1
View File
@@ -256,7 +256,9 @@ class TrainPipelineConfig(HubMixin):
) from e ) from e
cli_args = kwargs.pop("cli_args", []) cli_args = kwargs.pop("cli_args", [])
if config_file is not None: # Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
# Hand-written YAML/TOML configs are expected to use the current sample_weighting schema.
if config_file is not None and config_file.endswith(".json"):
with open(config_file) as f: with open(config_file) as f:
config = json.load(f) config = json.load(f)
migrated_config = _migrate_legacy_rabc_fields(config) migrated_config = _migrate_legacy_rabc_fields(config)
+4
View File
@@ -282,7 +282,11 @@ class VideoDecoderCache:
with self._lock: with self._lock:
if video_path not in self._cache: if video_path not in self._cache:
file_handle = fsspec.open(video_path).__enter__() file_handle = fsspec.open(video_path).__enter__()
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate") decoder = VideoDecoder(file_handle, seek_mode="approximate")
except Exception:
file_handle.close()
raise
self._cache[video_path] = (decoder, file_handle) self._cache[video_path] = (decoder, file_handle)
return self._cache[video_path][0] return self._cache[video_path][0]
@@ -939,7 +939,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
input_dtype = query_states.dtype input_dtype = query_states.dtype
if input_dtype == torch.float32: if input_dtype == torch.float32:
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype() target_dtype = torch.get_autocast_dtype(query_states.device.type)
# Handle the case where the model is quantized # Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"): elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype target_dtype = self.config._pre_quantization_dtype
@@ -985,7 +985,7 @@ class Florence2FlashAttention2(Florence2Attention):
input_dtype = query_states.dtype input_dtype = query_states.dtype
if input_dtype == torch.float32: if input_dtype == torch.float32:
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype() target_dtype = torch.get_autocast_dtype(query_states.device.type)
# Handle the case where the model is quantized # Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"): elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype target_dtype = self.config._pre_quantization_dtype
+51 -39
View File
@@ -52,27 +52,15 @@ import time
from collections.abc import Generator from collections.abc import Generator
from functools import lru_cache from functools import lru_cache
from queue import Empty from queue import Empty
from typing import Any from typing import TYPE_CHECKING, Any
from lerobot.utils.import_utils import require_package from lerobot.utils.import_utils import _grpc_available, require_package
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing. if TYPE_CHECKING or _grpc_available:
require_package("grpcio", extra="hilserl", import_name="grpc") import grpc
import grpc # noqa: E402 from lerobot.transport import services_pb2, services_pb2_grpc
import torch # noqa: E402 from lerobot.transport.utils import (
from torch import nn # noqa: E402
from torch.multiprocessing import Queue # noqa: E402
from lerobot.cameras import opencv # noqa: F401, E402
from lerobot.configs import parser # noqa: E402
from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402
from lerobot.processor import TransitionKey # noqa: E402
from lerobot.robots import so_follower # noqa: F401, E402
from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402
from lerobot.teleoperators.utils import TeleopEvents # noqa: E402
from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402
from lerobot.transport.utils import ( # noqa: E402
bytes_to_state_dict, bytes_to_state_dict,
grpc_channel_options, grpc_channel_options,
python_object_to_bytes, python_object_to_bytes,
@@ -80,35 +68,59 @@ from lerobot.transport.utils import ( # noqa: E402
send_bytes_in_chunks, send_bytes_in_chunks,
transitions_to_bytes, transitions_to_bytes,
) )
from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402 else:
from lerobot.utils.process import ProcessSignalHandler # noqa: E402 grpc = None
from lerobot.utils.random_utils import set_seed # noqa: E402 services_pb2 = None
from lerobot.utils.robot_utils import precise_sleep # noqa: E402 services_pb2_grpc = None
from lerobot.utils.transition import ( # noqa: E402 bytes_to_state_dict = None
grpc_channel_options = None
python_object_to_bytes = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
transitions_to_bytes = None
import torch
from torch import nn
from torch.multiprocessing import Queue
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.processor import TransitionKey
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.transition import (
Transition, Transition,
move_transition_to_device, move_transition_to_device,
) )
from lerobot.utils.utils import ( # noqa: E402 from lerobot.utils.utils import (
TimerManager, TimerManager,
init_logging, init_logging,
) )
from .algorithms.base import RLAlgorithm # noqa: E402 from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm # noqa: E402 from .algorithms.factory import make_algorithm
from .gym_manipulator import ( # noqa: E402 from .gym_manipulator import (
make_processors, make_processors,
make_robot_env, make_robot_env,
reset_and_build_transition, reset_and_build_transition,
step_env_and_process_transition, step_env_and_process_transition,
) )
from .queue import get_last_item_from_queue # noqa: E402 from .queue import get_last_item_from_queue
from .train_rl import TrainRLServerPipelineConfig # noqa: E402 from .train_rl import TrainRLServerPipelineConfig
# Main entry point # Main entry point
@parser.wrap() @parser.wrap()
def actor_cli(cfg: TrainRLServerPipelineConfig): def actor_cli(cfg: TrainRLServerPipelineConfig):
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
require_package("grpcio", extra="hilserl", import_name="grpc")
cfg.validate() cfg.validate()
display_pid = False display_pid = False
if not use_threads(cfg): if not use_threads(cfg):
@@ -421,7 +433,7 @@ def act_with_policy(
def establish_learner_connection( def establish_learner_connection(
stub: services_pb2_grpc.LearnerServiceStub, stub: "services_pb2_grpc.LearnerServiceStub",
shutdown_event: Any, # Event shutdown_event: Any, # Event
attempts: int = 30, attempts: int = 30,
) -> bool: ) -> bool:
@@ -454,7 +466,7 @@ def establish_learner_connection(
def learner_service_client( def learner_service_client(
host: str = "127.0.0.1", host: str = "127.0.0.1",
port: int = 50051, port: int = 50051,
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: ) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]":
"""Return a client for the learner service. """Return a client for the learner service.
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection. GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
@@ -477,8 +489,8 @@ def receive_policy(
cfg: TrainRLServerPipelineConfig, cfg: TrainRLServerPipelineConfig,
parameters_queue: Queue, parameters_queue: Queue,
shutdown_event: Any, # Event shutdown_event: Any, # Event
learner_client: services_pb2_grpc.LearnerServiceStub | None = None, learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: grpc.Channel | None = None, grpc_channel: "grpc.Channel | None" = None,
) -> None: ) -> None:
"""Receive parameters from the learner. """Receive parameters from the learner.
@@ -531,8 +543,8 @@ def send_transitions(
cfg: TrainRLServerPipelineConfig, cfg: TrainRLServerPipelineConfig,
transitions_queue: Queue, transitions_queue: Queue,
shutdown_event: Any, # Event shutdown_event: Any, # Event
learner_client: services_pb2_grpc.LearnerServiceStub | None = None, learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: grpc.Channel | None = None, grpc_channel: "grpc.Channel | None" = None,
) -> None: ) -> None:
"""Send transitions to the learner. """Send transitions to the learner.
@@ -587,8 +599,8 @@ def send_interactions(
cfg: TrainRLServerPipelineConfig, cfg: TrainRLServerPipelineConfig,
interactions_queue: Queue, interactions_queue: Queue,
shutdown_event: Any, # Event shutdown_event: Any, # Event
learner_client: services_pb2_grpc.LearnerServiceStub | None = None, learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: grpc.Channel | None = None, grpc_channel: "grpc.Channel | None" = None,
) -> None: ) -> None:
"""Send interactions to the learner. """Send interactions to the learner.
@@ -646,7 +658,7 @@ def transitions_stream(
shutdown_event: Any, # Event shutdown_event: Any, # Event
transitions_queue: Queue, transitions_queue: Queue,
timeout: float, timeout: float,
) -> Generator[Any, None, services_pb2.Empty]: ) -> "Generator[Any, None, services_pb2.Empty]":
while not shutdown_event.is_set(): while not shutdown_event.is_set():
try: try:
message = transitions_queue.get(block=True, timeout=timeout) message = transitions_queue.get(block=True, timeout=timeout)
@@ -665,7 +677,7 @@ def interactions_stream(
shutdown_event: Any, # Event shutdown_event: Any, # Event
interactions_queue: Queue, interactions_queue: Queue,
timeout: float, timeout: float,
) -> Generator[Any, None, services_pb2.Empty]: ) -> "Generator[Any, None, services_pb2.Empty]":
while not shutdown_event.is_set(): while not shutdown_event.is_set():
try: try:
message = interactions_queue.get(block=True, timeout=timeout) message = interactions_queue.get(block=True, timeout=timeout)
+41 -36
View File
@@ -51,44 +51,47 @@ import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
from typing import Any from typing import TYPE_CHECKING, Any
from lerobot.utils.import_utils import require_package from lerobot.utils.import_utils import _grpc_available, require_package
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing. if TYPE_CHECKING or _grpc_available:
require_package("grpcio", extra="hilserl", import_name="grpc") import grpc
import grpc # noqa: E402 from lerobot.transport import services_pb2_grpc
import torch # noqa: E402 else:
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE # noqa: E402 grpc = None
from safetensors.torch import load_file as load_safetensors # noqa: E402 services_pb2_grpc = None
from termcolor import colored # noqa: E402
from torch import nn # noqa: E402
from torch.multiprocessing import Queue # noqa: E402
from torch.optim.optimizer import Optimizer # noqa: E402
from lerobot.cameras import opencv # noqa: F401, E402 import torch
from lerobot.common.train_utils import ( # noqa: E402 from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from safetensors.torch import load_file as load_safetensors
from termcolor import colored
from torch import nn
from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer
from lerobot.cameras import opencv # noqa: F401
from lerobot.common.train_utils import (
get_step_checkpoint_dir, get_step_checkpoint_dir,
load_training_state as utils_load_training_state, load_training_state as utils_load_training_state,
save_checkpoint, save_checkpoint,
update_last_checkpoint, update_last_checkpoint,
) )
from lerobot.common.wandb_utils import WandBLogger # noqa: E402 from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser # noqa: E402 from lerobot.configs import parser
from lerobot.datasets import LeRobotDataset, make_dataset # noqa: E402 from lerobot.datasets import LeRobotDataset, make_dataset
from lerobot.policies import make_policy, make_pre_post_processors # noqa: E402 from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.robots import so_follower # noqa: F401, E402 from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401, E402 from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents # noqa: E402 from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2_grpc # noqa: E402 from lerobot.transport.utils import (
from lerobot.transport.utils import ( # noqa: E402
MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE,
bytes_to_python_object, bytes_to_python_object,
bytes_to_transitions, bytes_to_transitions,
state_to_bytes, state_to_bytes,
) )
from lerobot.utils.constants import ( # noqa: E402 from lerobot.utils.constants import (
ACTION, ACTION,
ALGORITHM_DIR, ALGORITHM_DIR,
CHECKPOINTS_DIR, CHECKPOINTS_DIR,
@@ -97,26 +100,28 @@ from lerobot.utils.constants import ( # noqa: E402
TRAINING_STATE_DIR, TRAINING_STATE_DIR,
TRAINING_STEP, TRAINING_STEP,
) )
from lerobot.utils.device_utils import get_safe_torch_device # noqa: E402 from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.io_utils import load_json, write_json # noqa: E402 from lerobot.utils.io_utils import load_json, write_json
from lerobot.utils.process import ProcessSignalHandler # noqa: E402 from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed # noqa: E402 from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import ( # noqa: E402 from lerobot.utils.utils import (
format_big_number, format_big_number,
init_logging, init_logging,
) )
from .algorithms.base import RLAlgorithm # noqa: E402 from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm # noqa: E402 from .algorithms.factory import make_algorithm
from .buffer import ReplayBuffer # noqa: E402 from .buffer import ReplayBuffer
from .data_sources import OnlineOfflineMixer # noqa: E402 from .data_sources import OnlineOfflineMixer
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService # noqa: E402 from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
from .train_rl import TrainRLServerPipelineConfig # noqa: E402 from .train_rl import TrainRLServerPipelineConfig
from .trainer import RLTrainer # noqa: E402 from .trainer import RLTrainer
@parser.wrap() @parser.wrap()
def train_cli(cfg: TrainRLServerPipelineConfig): def train_cli(cfg: TrainRLServerPipelineConfig):
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
require_package("grpcio", extra="hilserl", import_name="grpc")
if not use_threads(cfg): if not use_threads(cfg):
import torch.multiprocessing as mp import torch.multiprocessing as mp
+23 -11
View File
@@ -18,22 +18,32 @@
import logging import logging
import time import time
from multiprocessing import Event, Queue from multiprocessing import Event, Queue
from typing import TYPE_CHECKING
from lerobot.utils.import_utils import require_package from lerobot.utils.import_utils import _grpc_available
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing. from .queue import get_last_item_from_queue
require_package("grpcio", extra="hilserl", import_name="grpc")
from lerobot.transport import services_pb2, services_pb2_grpc # noqa: E402 if TYPE_CHECKING or _grpc_available:
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks # noqa: E402 import grpc
from .queue import get_last_item_from_queue # noqa: E402 from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
_ServicerBase = services_pb2_grpc.LearnerServiceServicer
else:
grpc = None
services_pb2 = None
services_pb2_grpc = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
_ServicerBase = object
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
SHUTDOWN_TIMEOUT = 10 SHUTDOWN_TIMEOUT = 10
class LearnerService(services_pb2_grpc.LearnerServiceServicer): class LearnerService(_ServicerBase):
""" """
Implementation of the LearnerService gRPC service Implementation of the LearnerService gRPC service
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
@@ -56,7 +66,9 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
self.interaction_message_queue = interaction_message_queue self.interaction_message_queue = interaction_message_queue
self.queue_get_timeout = queue_get_timeout self.queue_get_timeout = queue_get_timeout
def StreamParameters(self, request, context): # noqa: N802 def StreamParameters( # noqa: N802
self, request: "services_pb2.Empty", context: "grpc.ServicerContext"
):
# TODO: authorize the request # TODO: authorize the request
logging.info("[LEARNER] Received request to stream parameters from the Actor") logging.info("[LEARNER] Received request to stream parameters from the Actor")
@@ -91,7 +103,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
logging.info("[LEARNER] Stream parameters finished") logging.info("[LEARNER] Stream parameters finished")
return services_pb2.Empty() return services_pb2.Empty()
def SendTransitions(self, request_iterator, _context): # noqa: N802 def SendTransitions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
# TODO: authorize the request # TODO: authorize the request
logging.info("[LEARNER] Received request to receive transitions from the Actor") logging.info("[LEARNER] Received request to receive transitions from the Actor")
@@ -105,7 +117,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
logging.debug("[LEARNER] Finished receiving transitions") logging.debug("[LEARNER] Finished receiving transitions")
return services_pb2.Empty() return services_pb2.Empty()
def SendInteractions(self, request_iterator, _context): # noqa: N802 def SendInteractions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
# TODO: authorize the request # TODO: authorize the request
logging.info("[LEARNER] Received request to receive interactions from the Actor") logging.info("[LEARNER] Received request to receive interactions from the Actor")
@@ -119,5 +131,5 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
logging.debug("[LEARNER] Finished receiving interactions") logging.debug("[LEARNER] Finished receiving interactions")
return services_pb2.Empty() return services_pb2.Empty()
def Ready(self, request, context): # noqa: N802 def Ready(self, request: "services_pb2.Empty", context: "grpc.ServicerContext"): # noqa: N802
return services_pb2.Empty() return services_pb2.Empty()
+1 -1
View File
@@ -46,7 +46,7 @@ class LeKiwiConfig(RobotConfig):
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config) cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
# Set to `True` for backward compatibility with previous policies/dataset # Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False use_degrees: bool = True
@dataclass @dataclass
Generated
+644 -537
View File
File diff suppressed because it is too large Load Diff