mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
fix policy imports
This commit is contained in:
@@ -1,6 +1,16 @@
|
||||
# Async inference server/client.
|
||||
# Requires: lerobot[async]
|
||||
"""
|
||||
Async inference server/client.
|
||||
|
||||
Requires: ``pip install 'lerobot[async]'``
|
||||
|
||||
Available modules (import directly)::
|
||||
|
||||
from lerobot.async_inference.policy_server import ...
|
||||
from lerobot.async_inference.robot_client import ...
|
||||
"""
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("grpcio", extra="async", import_name="grpc")
|
||||
|
||||
__all__: list[str] = []
|
||||
|
||||
@@ -5,4 +5,12 @@ Unlike ``lerobot.utils`` (which must remain dependency-free), modules here
|
||||
are allowed to import from ``lerobot.policies``, ``lerobot.processor``,
|
||||
``lerobot.configs``, etc. They are deliberately NOT re-exported from the
|
||||
top-level ``lerobot`` package.
|
||||
|
||||
Available modules (import directly)::
|
||||
|
||||
from lerobot.common.control_utils import predict_action, ...
|
||||
from lerobot.common.train_utils import save_checkpoint, ...
|
||||
from lerobot.common.wandb_utils import WandBLogger, ...
|
||||
"""
|
||||
|
||||
__all__: list[str] = []
|
||||
|
||||
@@ -11,3 +11,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Data processing utilities (annotation tools, dataset transformations).
|
||||
|
||||
Available sub-modules (import directly)::
|
||||
|
||||
from lerobot.data_processing.sarm_annotations import ...
|
||||
"""
|
||||
|
||||
__all__: list[str] = []
|
||||
|
||||
@@ -11,3 +11,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SARM subtask annotation tools.
|
||||
|
||||
Available modules (import directly)::
|
||||
|
||||
from lerobot.data_processing.sarm_annotations.subtask_annotation import ...
|
||||
"""
|
||||
|
||||
__all__: list[str] = []
|
||||
|
||||
@@ -12,6 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# NOTE: gymnasium is currently a core dependency but is a candidate for moving to an
|
||||
# optional extra in the future. This guard is here to ensure a clear error message
|
||||
# if/when that transition happens.
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("gymnasium", extra="evaluation", import_name="gymnasium")
|
||||
|
||||
from .configs import AlohaEnv, EnvConfig, HILSerlRobotEnvConfig, HubEnvConfig, PushtEnv
|
||||
from .factory import make_env, make_env_config, make_env_pre_post_processors
|
||||
from .utils import check_env_attributes_and_types, close_envs, env_to_policy_features, preprocess_observation
|
||||
|
||||
@@ -22,11 +22,10 @@ from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .rtc import ActionInterpolator as ActionInterpolator
|
||||
from .sac.configuration_sac import SACConfig as SACConfig
|
||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla.processor_smolvla import (
|
||||
SmolVLANewLineProcessor as SmolVLANewLineProcessor, # noqa: F401 - registers with ProcessorStepRegistry
|
||||
)
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .utils import make_robot_action, prepare_observation_for_inference
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
@@ -46,6 +45,8 @@ __all__ = [
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
"RewardClassifierConfig",
|
||||
"SACConfig",
|
||||
"SARMConfig",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .configuration_sarm import SARMConfig
|
||||
from .modeling_sarm import SARMRewardModel
|
||||
|
||||
__all__ = ["SARMConfig"]
|
||||
__all__ = ["SARMConfig", "SARMRewardModel"]
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from .configuration_smolvla import SmolVLAConfig
|
||||
from .processor_smolvla import (
|
||||
SmolVLANewLineProcessor,
|
||||
make_smolvla_pre_post_processors,
|
||||
)
|
||||
from .modeling_smolvla import SmolVLAPolicy
|
||||
from .processor_smolvla import make_smolvla_pre_post_processors
|
||||
|
||||
__all__ = ["SmolVLAConfig", "SmolVLANewLineProcessor", "make_smolvla_pre_post_processors"]
|
||||
__all__ = ["SmolVLAConfig", "SmolVLAPolicy", "make_smolvla_pre_post_processors"]
|
||||
|
||||
@@ -18,15 +18,13 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NewLineTaskProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
@@ -71,7 +69,7 @@ def make_smolvla_pre_post_processors(
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
SmolVLANewLineProcessor(),
|
||||
NewLineTaskProcessorStep(),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=config.vlm_model_name,
|
||||
padding=config.pad_language_to,
|
||||
@@ -103,41 +101,3 @@ def make_smolvla_pre_post_processors(
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
|
||||
class SmolVLANewLineProcessor(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
A processor step that ensures the 'task' description ends with a newline character.
|
||||
|
||||
This step is necessary for certain tokenizers (e.g., PaliGemma) that expect a
|
||||
newline at the end of the prompt. It handles both single string tasks and lists
|
||||
of string tasks.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
if "task" not in complementary_data:
|
||||
return complementary_data
|
||||
|
||||
task = complementary_data["task"]
|
||||
if task is None:
|
||||
return complementary_data
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
|
||||
# Handle both string and list of strings
|
||||
if isinstance(task, str):
|
||||
# Single string: add newline if not present
|
||||
if not task.endswith("\n"):
|
||||
new_complementary_data["task"] = f"{task}\n"
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
# List of strings: add newline to each if not present
|
||||
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
@@ -13,21 +13,27 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
require_package("transformers", extra="smolvla")
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForImageTextToText,
|
||||
AutoProcessor,
|
||||
SmolVLMForConditionalGeneration,
|
||||
)
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForImageTextToText,
|
||||
AutoProcessor,
|
||||
SmolVLMForConditionalGeneration,
|
||||
)
|
||||
else:
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
AutoModelForImageTextToText = None
|
||||
AutoProcessor = None
|
||||
SmolVLMForConditionalGeneration = None
|
||||
|
||||
|
||||
def apply_rope(x, positions, max_wavelength=10_000):
|
||||
@@ -78,6 +84,7 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
device: str = "auto",
|
||||
):
|
||||
super().__init__()
|
||||
require_package("transformers", extra="smolvla")
|
||||
if load_vlm_weights:
|
||||
print(f"Loading {model_id} weights ...")
|
||||
self.vlm = AutoModelForImageTextToText.from_pretrained(
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_wall_x import WallXConfig
|
||||
from .modeling_wall_x import WallXPolicy
|
||||
from .processor_wall_x import make_wall_x_pre_post_processors
|
||||
|
||||
__all__ = ["WallXConfig", "make_wall_x_pre_post_processors"]
|
||||
__all__ = ["WallXConfig", "WallXPolicy", "make_wall_x_pre_post_processors"]
|
||||
|
||||
@@ -34,40 +34,28 @@ lerobot-train \
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from os import PathLike
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("transformers", extra="wallx")
|
||||
require_package("peft", extra="wallx")
|
||||
require_package("torchdiffeq", extra="wallx")
|
||||
require_package("qwen-vl-utils", extra="wallx", import_name="qwen_vl_utils")
|
||||
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from PIL import Image
|
||||
from qwen_vl_utils.vision_process import smart_resize
|
||||
from torch import Tensor
|
||||
from torch.distributions import Beta
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torchdiffeq import odeint
|
||||
from transformers import AutoProcessor, BatchFeature
|
||||
from transformers.cache_utils import (
|
||||
StaticCache,
|
||||
)
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
Qwen2_5_VLForConditionalGeneration,
|
||||
)
|
||||
from transformers.utils import is_torchdynamo_compiling, logging
|
||||
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import (
|
||||
_peft_available,
|
||||
_transformers_available,
|
||||
is_package_available,
|
||||
require_package,
|
||||
)
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from ..utils import populate_queues
|
||||
@@ -82,12 +70,45 @@ from .constant import (
|
||||
RESOLUTION,
|
||||
TOKENIZER_MAX_LENGTH,
|
||||
)
|
||||
from .qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
|
||||
from .qwen_model.qwen2_5_vl_moe import (
|
||||
Qwen2_5_VisionTransformerPretrainedModel,
|
||||
Qwen2_5_VLACausalLMOutputWithPast,
|
||||
Qwen2_5_VLMoEModel,
|
||||
|
||||
_torchdiffeq_available = is_package_available("torchdiffeq")
|
||||
_qwen_vl_utils_available = is_package_available("qwen-vl-utils", import_name="qwen_vl_utils")
|
||||
_wallx_deps_available = (
|
||||
_transformers_available and _peft_available and _torchdiffeq_available and _qwen_vl_utils_available
|
||||
)
|
||||
|
||||
if TYPE_CHECKING or _wallx_deps_available:
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from qwen_vl_utils.vision_process import smart_resize
|
||||
from torchdiffeq import odeint
|
||||
from transformers import AutoProcessor, BatchFeature
|
||||
from transformers.cache_utils import StaticCache
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
Qwen2_5_VLForConditionalGeneration,
|
||||
)
|
||||
from transformers.utils import is_torchdynamo_compiling
|
||||
|
||||
from .qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
|
||||
from .qwen_model.qwen2_5_vl_moe import (
|
||||
Qwen2_5_VisionTransformerPretrainedModel,
|
||||
Qwen2_5_VLACausalLMOutputWithPast,
|
||||
Qwen2_5_VLMoEModel,
|
||||
)
|
||||
else:
|
||||
LoraConfig = None
|
||||
get_peft_model = None
|
||||
smart_resize = None
|
||||
odeint = None
|
||||
AutoProcessor = None
|
||||
BatchFeature = None
|
||||
StaticCache = None
|
||||
Qwen2_5_VLForConditionalGeneration = None
|
||||
is_torchdynamo_compiling = None
|
||||
Qwen2_5_VLConfig = None
|
||||
Qwen2_5_VisionTransformerPretrainedModel = None
|
||||
Qwen2_5_VLACausalLMOutputWithPast = None
|
||||
Qwen2_5_VLMoEModel = None
|
||||
|
||||
from .utils import (
|
||||
get_wallx_normal_text,
|
||||
preprocesser_call,
|
||||
@@ -95,7 +116,7 @@ from .utils import (
|
||||
replace_action_token,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
@@ -262,7 +283,13 @@ class ActionHead(nn.Module):
|
||||
return self.propri_proj(proprioception)
|
||||
|
||||
|
||||
class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
# Conditional base: when transformers is unavailable the class still parses
|
||||
# (inheriting from nn.Module) but cannot be instantiated—require_package in
|
||||
# WallXPolicy.__init__ gives the user a clear error before that happens.
|
||||
_Qwen2_5_VLForAction_Base = Qwen2_5_VLForConditionalGeneration if _wallx_deps_available else nn.Module
|
||||
|
||||
|
||||
class Qwen2_5_VLMoEForAction(_Qwen2_5_VLForAction_Base):
|
||||
"""
|
||||
Qwen2.5 Vision-Language Mixture of Experts model for action processing.
|
||||
|
||||
@@ -1717,6 +1744,10 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
name = "wall_x"
|
||||
|
||||
def __init__(self, config: WallXConfig, **kwargs):
|
||||
require_package("transformers", extra="wallx")
|
||||
require_package("peft", extra="wallx")
|
||||
require_package("torchdiffeq", extra="wallx")
|
||||
require_package("qwen-vl-utils", extra="wallx", import_name="qwen_vl_utils")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
@@ -25,10 +25,16 @@ import random
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from transformers import BatchFeature
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import BatchFeature
|
||||
else:
|
||||
BatchFeature = None
|
||||
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .configuration_xvla import XVLAConfig
|
||||
from .modeling_xvla import XVLAPolicy
|
||||
from .processor_xvla import (
|
||||
XVLAAddDomainIdProcessorStep,
|
||||
XVLAImageNetNormalizeProcessorStep,
|
||||
@@ -7,6 +8,7 @@ from .processor_xvla import (
|
||||
|
||||
__all__ = [
|
||||
"XVLAConfig",
|
||||
"XVLAPolicy",
|
||||
"XVLAAddDomainIdProcessorStep",
|
||||
"XVLAImageNetNormalizeProcessorStep",
|
||||
"XVLAImageToFloatProcessorStep",
|
||||
|
||||
@@ -23,10 +23,7 @@ import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("transformers", extra="xvla")
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
@@ -34,15 +31,22 @@ from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from ..pretrained import PreTrainedPolicy, T
|
||||
from ..utils import populate_queues
|
||||
from .action_hub import build_action_space
|
||||
from .configuration_florence2 import Florence2Config
|
||||
from .configuration_xvla import XVLAConfig
|
||||
from .modeling_florence2 import Florence2ForConditionalGeneration
|
||||
from .soft_transformer import SoftPromptedTransformer
|
||||
|
||||
# Florence2 config and modeling depend on transformers
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from .configuration_florence2 import Florence2Config
|
||||
from .modeling_florence2 import Florence2ForConditionalGeneration
|
||||
else:
|
||||
Florence2Config = None
|
||||
Florence2ForConditionalGeneration = None
|
||||
|
||||
|
||||
class XVLAModel(nn.Module):
|
||||
"""
|
||||
@@ -278,6 +282,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
name = "xvla"
|
||||
|
||||
def __init__(self, config: XVLAConfig, **kwargs):
|
||||
require_package("transformers", extra="xvla")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
florence_config = config.get_florence_config()
|
||||
|
||||
@@ -56,6 +56,7 @@ from .hil_processor import (
|
||||
RewardClassifierProcessorStep,
|
||||
TimeLimitProcessorStep,
|
||||
)
|
||||
from .newline_task_processor import NewLineTaskProcessorStep
|
||||
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats
|
||||
from .observation_processor import VanillaObservationProcessorStep
|
||||
from .pipeline import (
|
||||
@@ -119,6 +120,7 @@ __all__ = [
|
||||
"RelativeActionsProcessorStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"NewLineTaskProcessorStep",
|
||||
"NormalizerProcessorStep",
|
||||
"Numpy2TorchActionProcessorStep",
|
||||
"ObservationProcessorStep",
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
|
||||
from .pipeline import ComplementaryDataProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
# NOTE: The registry name "smolvla_new_line_processor" is kept for backward compatibility
|
||||
# with serialized processor configs that reference this name.
|
||||
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
|
||||
class NewLineTaskProcessorStep(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
A processor step that ensures the 'task' description ends with a newline character.
|
||||
|
||||
This step is necessary for certain tokenizers (e.g., PaliGemma) that expect a
|
||||
newline at the end of the prompt. It handles both single string tasks and lists
|
||||
of string tasks.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
if "task" not in complementary_data:
|
||||
return complementary_data
|
||||
|
||||
task = complementary_data["task"]
|
||||
if task is None:
|
||||
return complementary_data
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
|
||||
# Handle both string and list of strings
|
||||
if isinstance(task, str):
|
||||
# Single string: add newline if not present
|
||||
if not task.endswith("\n"):
|
||||
new_complementary_data["task"] = f"{task}\n"
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
# List of strings: add newline to each if not present
|
||||
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
@@ -20,9 +20,11 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import ActionProcessorStep, PolicyAction, ProcessorStepRegistry, RobotAction
|
||||
from lerobot.types import PolicyAction, RobotAction
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("robot_action_to_policy_action_processor")
|
||||
|
||||
@@ -1,6 +1,20 @@
|
||||
# Reinforcement learning modules.
|
||||
# Requires: lerobot[hilserl]
|
||||
"""
|
||||
Reinforcement learning modules.
|
||||
|
||||
Requires: ``pip install 'lerobot[hilserl]'``
|
||||
|
||||
Available modules (import directly)::
|
||||
|
||||
from lerobot.rl.actor import ...
|
||||
from lerobot.rl.learner import ...
|
||||
from lerobot.rl.learner_service import ...
|
||||
from lerobot.rl.buffer import ...
|
||||
from lerobot.rl.eval_policy import ...
|
||||
from lerobot.rl.gym_manipulator import ...
|
||||
"""
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||
|
||||
__all__: list[str] = []
|
||||
|
||||
@@ -66,6 +66,7 @@ from lerobot.common.train_utils import (
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.datasets import LeRobotDataset, make_dataset
|
||||
@@ -99,7 +100,6 @@ from lerobot.utils.utils import (
|
||||
from .buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
|
||||
from .process import ProcessSignalHandler
|
||||
from .wandb_utils import WandBLogger
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
@@ -151,7 +151,7 @@ def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None):
|
||||
|
||||
# Setup WandB logging if enabled
|
||||
if cfg.wandb.enable and cfg.wandb.project:
|
||||
from .wandb_utils import WandBLogger
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
|
||||
wandb_logger = WandBLogger(cfg)
|
||||
else:
|
||||
|
||||
@@ -38,10 +38,8 @@ import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.cameras import ColorMode
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.cameras.opencv.camera_opencv import OpenCVCamera
|
||||
from lerobot.cameras.realsense import RealSenseCameraConfig
|
||||
from lerobot.cameras.realsense.camera_realsense import RealSenseCamera
|
||||
from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
|
||||
from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -142,7 +142,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
|
||||
from lerobot.teleoperators.keyboard import KeyboardTeleop
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
|
||||
@@ -35,13 +35,13 @@ from lerobot.common.train_utils import (
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.rl.wandb_utils import WandBLogger
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
|
||||
@@ -1,2 +1,11 @@
|
||||
# gRPC transport layer for async inference.
|
||||
# Requires: lerobot[grpcio-dep]
|
||||
"""
|
||||
gRPC transport layer for async inference.
|
||||
|
||||
Requires: ``pip install 'lerobot[grpcio-dep]'``
|
||||
|
||||
Available modules (import directly)::
|
||||
|
||||
from lerobot.transport.utils import ...
|
||||
"""
|
||||
|
||||
__all__: list[str] = []
|
||||
|
||||
Reference in New Issue
Block a user