fix policy imports

This commit is contained in:
Steven Palma
2026-04-11 20:39:03 +02:00
parent af0d72bd42
commit c9636bb53f
26 changed files with 266 additions and 127 deletions
+12 -2
View File
@@ -1,6 +1,16 @@
# Async inference server/client.
# Requires: lerobot[async]
"""
Async inference server/client.
Requires: ``pip install 'lerobot[async]'``
Available modules (import directly)::
from lerobot.async_inference.policy_server import ...
from lerobot.async_inference.robot_client import ...
"""
from lerobot.utils.import_utils import require_package
require_package("grpcio", extra="async", import_name="grpc")
__all__: list[str] = []
+8
View File
@@ -5,4 +5,12 @@ Unlike ``lerobot.utils`` (which must remain dependency-free), modules here
are allowed to import from ``lerobot.policies``, ``lerobot.processor``,
``lerobot.configs``, etc. They are deliberately NOT re-exported from the
top-level ``lerobot`` package.
Available modules (import directly)::
from lerobot.common.control_utils import predict_action, ...
from lerobot.common.train_utils import save_checkpoint, ...
from lerobot.common.wandb_utils import WandBLogger, ...
"""
__all__: list[str] = []
+10
View File
@@ -11,3 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data processing utilities (annotation tools, dataset transformations).
Available sub-modules (import directly)::
from lerobot.data_processing.sarm_annotations import ...
"""
__all__: list[str] = []
@@ -11,3 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SARM subtask annotation tools.
Available modules (import directly)::
from lerobot.data_processing.sarm_annotations.subtask_annotation import ...
"""
__all__: list[str] = []
+7
View File
@@ -12,6 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: gymnasium is currently a core dependency but is a candidate for moving to an
# optional extra in the future. This guard is here to ensure a clear error message
# if/when that transition happens.
from lerobot.utils.import_utils import require_package
require_package("gymnasium", extra="evaluation", import_name="gymnasium")
from .configs import AlohaEnv, EnvConfig, HILSerlRobotEnvConfig, HubEnvConfig, PushtEnv
from .factory import make_env, make_env_config, make_env_pre_post_processors
from .utils import check_env_attributes_and_types, close_envs, env_to_policy_features, preprocess_observation
+4 -3
View File
@@ -22,11 +22,10 @@ from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
from .rtc import ActionInterpolator as ActionInterpolator
from .sac.configuration_sac import SACConfig as SACConfig
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
from .sarm.configuration_sarm import SARMConfig as SARMConfig
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import (
SmolVLANewLineProcessor as SmolVLANewLineProcessor, # noqa: F401 - registers with ProcessorStepRegistry
)
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .utils import make_robot_action, prepare_observation_for_inference
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
@@ -46,6 +45,8 @@ __all__ = [
"PI0Config",
"PI0FastConfig",
"PI05Config",
"RewardClassifierConfig",
"SACConfig",
"SARMConfig",
"SmolVLAConfig",
"TDMPCConfig",
+2 -1
View File
@@ -1,3 +1,4 @@
from .configuration_sarm import SARMConfig
from .modeling_sarm import SARMRewardModel
__all__ = ["SARMConfig"]
__all__ = ["SARMConfig", "SARMRewardModel"]
+3 -5
View File
@@ -1,7 +1,5 @@
from .configuration_smolvla import SmolVLAConfig
from .processor_smolvla import (
SmolVLANewLineProcessor,
make_smolvla_pre_post_processors,
)
from .modeling_smolvla import SmolVLAPolicy
from .processor_smolvla import make_smolvla_pre_post_processors
__all__ = ["SmolVLAConfig", "SmolVLANewLineProcessor", "make_smolvla_pre_post_processors"]
__all__ = ["SmolVLAConfig", "SmolVLAPolicy", "make_smolvla_pre_post_processors"]
@@ -18,15 +18,13 @@ from typing import Any
import torch
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeviceProcessorStep,
NewLineTaskProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
@@ -71,7 +69,7 @@ def make_smolvla_pre_post_processors(
input_steps = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
SmolVLANewLineProcessor(),
NewLineTaskProcessorStep(),
TokenizerProcessorStep(
tokenizer_name=config.vlm_model_name,
padding=config.pad_language_to,
@@ -103,41 +101,3 @@ def make_smolvla_pre_post_processors(
to_output=transition_to_policy_action,
),
)
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
class SmolVLANewLineProcessor(ComplementaryDataProcessorStep):
    """
    Processor step that makes sure the ``task`` description is newline-terminated.

    Some tokenizers (e.g., PaliGemma) expect the prompt to end with a newline.
    The step accepts either a single task string or a list of task strings.
    """

    def complementary_data(self, complementary_data):
        """
        Return the complementary data with a newline-terminated ``task``.

        Args:
            complementary_data: Mapping that may hold a ``"task"`` entry.

        Returns:
            The input object itself when ``"task"`` is absent or ``None``;
            otherwise a shallow ``dict`` copy in which a string task, or every
            string of an all-string list, ends with ``"\\n"``. A task of any
            other type is carried over unchanged in the copy.
        """
        if "task" not in complementary_data:
            return complementary_data

        raw_task = complementary_data["task"]
        if raw_task is None:
            return complementary_data

        # Work on a shallow copy so the caller's mapping is never mutated.
        result = dict(complementary_data)

        if isinstance(raw_task, str):
            # Single prompt: append the newline only when it is missing.
            if not raw_task.endswith("\n"):
                result["task"] = f"{raw_task}\n"
        elif isinstance(raw_task, list):
            # Batch of prompts: only rewrite when every entry is a string.
            if all(isinstance(entry, str) for entry in raw_task):
                result["task"] = [entry if entry.endswith("\n") else f"{entry}\n" for entry in raw_task]

        return result

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return ``features`` unchanged; this step does not alter the feature spec."""
        return features
@@ -13,21 +13,27 @@
# limitations under the License.
import copy
from typing import TYPE_CHECKING
import torch
from torch import nn
from lerobot.utils.import_utils import require_package
from lerobot.utils.import_utils import _transformers_available, require_package
require_package("transformers", extra="smolvla")
from transformers import (
AutoConfig,
AutoModel,
AutoModelForImageTextToText,
AutoProcessor,
SmolVLMForConditionalGeneration,
)
if TYPE_CHECKING or _transformers_available:
from transformers import (
AutoConfig,
AutoModel,
AutoModelForImageTextToText,
AutoProcessor,
SmolVLMForConditionalGeneration,
)
else:
AutoConfig = None
AutoModel = None
AutoModelForImageTextToText = None
AutoProcessor = None
SmolVLMForConditionalGeneration = None
def apply_rope(x, positions, max_wavelength=10_000):
@@ -78,6 +84,7 @@ class SmolVLMWithExpertModel(nn.Module):
device: str = "auto",
):
super().__init__()
require_package("transformers", extra="smolvla")
if load_vlm_weights:
print(f"Loading {model_id} weights ...")
self.vlm = AutoModelForImageTextToText.from_pretrained(
+2 -1
View File
@@ -15,6 +15,7 @@
# limitations under the License.
from .configuration_wall_x import WallXConfig
from .modeling_wall_x import WallXPolicy
from .processor_wall_x import make_wall_x_pre_post_processors
__all__ = ["WallXConfig", "make_wall_x_pre_post_processors"]
__all__ = ["WallXConfig", "WallXPolicy", "make_wall_x_pre_post_processors"]
+58 -27
View File
@@ -34,40 +34,28 @@ lerobot-train \
```
"""
import logging
import math
from collections import deque
from os import PathLike
from typing import Any
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from lerobot.utils.import_utils import require_package
require_package("transformers", extra="wallx")
require_package("peft", extra="wallx")
require_package("torchdiffeq", extra="wallx")
require_package("qwen-vl-utils", extra="wallx", import_name="qwen_vl_utils")
from peft import LoraConfig, get_peft_model
from PIL import Image
from qwen_vl_utils.vision_process import smart_resize
from torch import Tensor
from torch.distributions import Beta
from torch.nn import CrossEntropyLoss
from torchdiffeq import odeint
from transformers import AutoProcessor, BatchFeature
from transformers.cache_utils import (
StaticCache,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from transformers.utils import is_torchdynamo_compiling, logging
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import (
_peft_available,
_transformers_available,
is_package_available,
require_package,
)
from ..pretrained import PreTrainedPolicy
from ..utils import populate_queues
@@ -82,12 +70,45 @@ from .constant import (
RESOLUTION,
TOKENIZER_MAX_LENGTH,
)
from .qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
from .qwen_model.qwen2_5_vl_moe import (
Qwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLACausalLMOutputWithPast,
Qwen2_5_VLMoEModel,
_torchdiffeq_available = is_package_available("torchdiffeq")
_qwen_vl_utils_available = is_package_available("qwen-vl-utils", import_name="qwen_vl_utils")
_wallx_deps_available = (
_transformers_available and _peft_available and _torchdiffeq_available and _qwen_vl_utils_available
)
if TYPE_CHECKING or _wallx_deps_available:
from peft import LoraConfig, get_peft_model
from qwen_vl_utils.vision_process import smart_resize
from torchdiffeq import odeint
from transformers import AutoProcessor, BatchFeature
from transformers.cache_utils import StaticCache
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from transformers.utils import is_torchdynamo_compiling
from .qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
from .qwen_model.qwen2_5_vl_moe import (
Qwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLACausalLMOutputWithPast,
Qwen2_5_VLMoEModel,
)
else:
LoraConfig = None
get_peft_model = None
smart_resize = None
odeint = None
AutoProcessor = None
BatchFeature = None
StaticCache = None
Qwen2_5_VLForConditionalGeneration = None
is_torchdynamo_compiling = None
Qwen2_5_VLConfig = None
Qwen2_5_VisionTransformerPretrainedModel = None
Qwen2_5_VLACausalLMOutputWithPast = None
Qwen2_5_VLMoEModel = None
from .utils import (
get_wallx_normal_text,
preprocesser_call,
@@ -95,7 +116,7 @@ from .utils import (
replace_action_token,
)
logger = logging.get_logger(__name__)
logger = logging.getLogger(__name__)
class SinusoidalPosEmb(nn.Module):
@@ -262,7 +283,13 @@ class ActionHead(nn.Module):
return self.propri_proj(proprioception)
class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
# Conditional base: when any of the Wall-X optional dependencies (transformers,
# peft, torchdiffeq, qwen-vl-utils) is unavailable, the class still parses
# (inheriting from nn.Module) but cannot be instantiated; require_package in
# WallXPolicy.__init__ gives the user a clear error before that happens.
_Qwen2_5_VLForAction_Base = Qwen2_5_VLForConditionalGeneration if _wallx_deps_available else nn.Module
class Qwen2_5_VLMoEForAction(_Qwen2_5_VLForAction_Base):
"""
Qwen2.5 Vision-Language Mixture of Experts model for action processing.
@@ -1717,6 +1744,10 @@ class WallXPolicy(PreTrainedPolicy):
name = "wall_x"
def __init__(self, config: WallXConfig, **kwargs):
require_package("transformers", extra="wallx")
require_package("peft", extra="wallx")
require_package("torchdiffeq", extra="wallx")
require_package("qwen-vl-utils", extra="wallx", import_name="qwen_vl_utils")
super().__init__(config)
config.validate_features()
self.config = config
+8 -2
View File
@@ -25,10 +25,16 @@ import random
import re
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any
from typing import TYPE_CHECKING, Any
import torch
from transformers import BatchFeature
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import BatchFeature
else:
BatchFeature = None
from lerobot.utils.constants import OBS_IMAGES
+2
View File
@@ -1,4 +1,5 @@
from .configuration_xvla import XVLAConfig
from .modeling_xvla import XVLAPolicy
from .processor_xvla import (
XVLAAddDomainIdProcessorStep,
XVLAImageNetNormalizeProcessorStep,
@@ -7,6 +8,7 @@ from .processor_xvla import (
__all__ = [
"XVLAConfig",
"XVLAPolicy",
"XVLAAddDomainIdProcessorStep",
"XVLAImageNetNormalizeProcessorStep",
"XVLAImageToFloatProcessorStep",
+11 -6
View File
@@ -23,10 +23,7 @@ import logging
import os
from collections import deque
from pathlib import Path
from lerobot.utils.import_utils import require_package
require_package("transformers", extra="xvla")
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F # noqa: N812
@@ -34,15 +31,22 @@ from torch import Tensor, nn
from lerobot.configs import PreTrainedConfig
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
from ..pretrained import PreTrainedPolicy, T
from ..utils import populate_queues
from .action_hub import build_action_space
from .configuration_florence2 import Florence2Config
from .configuration_xvla import XVLAConfig
from .modeling_florence2 import Florence2ForConditionalGeneration
from .soft_transformer import SoftPromptedTransformer
# Florence2 config and modeling depend on transformers
if TYPE_CHECKING or _transformers_available:
from .configuration_florence2 import Florence2Config
from .modeling_florence2 import Florence2ForConditionalGeneration
else:
Florence2Config = None
Florence2ForConditionalGeneration = None
class XVLAModel(nn.Module):
"""
@@ -278,6 +282,7 @@ class XVLAPolicy(PreTrainedPolicy):
name = "xvla"
def __init__(self, config: XVLAConfig, **kwargs):
require_package("transformers", extra="xvla")
super().__init__(config)
config.validate_features()
florence_config = config.get_florence_config()
+2
View File
@@ -56,6 +56,7 @@ from .hil_processor import (
RewardClassifierProcessorStep,
TimeLimitProcessorStep,
)
from .newline_task_processor import NewLineTaskProcessorStep
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats
from .observation_processor import VanillaObservationProcessorStep
from .pipeline import (
@@ -119,6 +120,7 @@ __all__ = [
"RelativeActionsProcessorStep",
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"NewLineTaskProcessorStep",
"NormalizerProcessorStep",
"Numpy2TorchActionProcessorStep",
"ObservationProcessorStep",
@@ -0,0 +1,59 @@
#!/usr/bin/env python
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.configs import PipelineFeatureType, PolicyFeature
from .pipeline import ComplementaryDataProcessorStep, ProcessorStepRegistry
# NOTE: The registry name "smolvla_new_line_processor" is kept for backward compatibility
# with serialized processor configs that reference this name.
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
class NewLineTaskProcessorStep(ComplementaryDataProcessorStep):
    """
    Ensure the 'task' description carried in the complementary data ends with a newline.

    Certain tokenizers (e.g., PaliGemma) expect a trailing newline on the prompt.
    Both a single task string and a list of task strings are supported.
    """

    def complementary_data(self, complementary_data):
        """
        Append a trailing newline to the ``task`` entry when it lacks one.

        Args:
            complementary_data: Mapping that may contain a ``"task"`` key.

        Returns:
            ``complementary_data`` itself if ``"task"`` is missing or ``None``.
            Otherwise a shallow ``dict`` copy where a string task, or each item
            of a list made only of strings, is newline-terminated; tasks of any
            other type are left as they are in the copy.
        """
        has_task = "task" in complementary_data
        prompt = complementary_data["task"] if has_task else None
        if prompt is None:
            # Nothing to normalise: hand back the caller's object untouched.
            return complementary_data

        # Copy first so the incoming mapping is not modified in place.
        updated = dict(complementary_data)

        if isinstance(prompt, str):
            if not prompt.endswith("\n"):
                updated["task"] = f"{prompt}\n"
            return updated

        if isinstance(prompt, list) and all(isinstance(item, str) for item in prompt):
            terminated = []
            for item in prompt:
                terminated.append(item if item.endswith("\n") else f"{item}\n")
            updated["task"] = terminated

        # Any other task type falls through unchanged.
        return updated

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Pass the feature description through untouched."""
        return features
+3 -1
View File
@@ -20,9 +20,11 @@ from typing import Any
import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import ActionProcessorStep, PolicyAction, ProcessorStepRegistry, RobotAction
from lerobot.types import PolicyAction, RobotAction
from lerobot.utils.constants import ACTION
from .pipeline import ActionProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register("robot_action_to_policy_action_processor")
+16 -2
View File
@@ -1,6 +1,20 @@
# Reinforcement learning modules.
# Requires: lerobot[hilserl]
"""
Reinforcement learning modules.
Requires: ``pip install 'lerobot[hilserl]'``
Available modules (import directly)::
from lerobot.rl.actor import ...
from lerobot.rl.learner import ...
from lerobot.rl.learner_service import ...
from lerobot.rl.buffer import ...
from lerobot.rl.eval_policy import ...
from lerobot.rl.gym_manipulator import ...
"""
from lerobot.utils.import_utils import require_package
require_package("grpcio", extra="hilserl", import_name="grpc")
__all__: list[str] = []
+2 -2
View File
@@ -66,6 +66,7 @@ from lerobot.common.train_utils import (
save_checkpoint,
update_last_checkpoint,
)
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets import LeRobotDataset, make_dataset
@@ -99,7 +100,6 @@ from lerobot.utils.utils import (
from .buffer import ReplayBuffer, concatenate_batch_transitions
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
from .process import ProcessSignalHandler
from .wandb_utils import WandBLogger
@parser.wrap()
@@ -151,7 +151,7 @@ def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None):
# Setup WandB logging if enabled
if cfg.wandb.enable and cfg.wandb.project:
from .wandb_utils import WandBLogger
from lerobot.common.wandb_utils import WandBLogger
wandb_logger = WandBLogger(cfg)
else:
+2 -4
View File
@@ -38,10 +38,8 @@ import numpy as np
from PIL import Image
from lerobot.cameras import ColorMode
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.cameras.opencv.camera_opencv import OpenCVCamera
from lerobot.cameras.realsense import RealSenseCameraConfig
from lerobot.cameras.realsense.camera_realsense import RealSenseCamera
from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig
logger = logging.getLogger(__name__)
+1 -1
View File
@@ -142,7 +142,7 @@ from lerobot.teleoperators import ( # noqa: F401
so_leader,
unitree_g1,
)
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
from lerobot.teleoperators.keyboard import KeyboardTeleop
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
+1 -1
View File
@@ -35,13 +35,13 @@ from lerobot.common.train_utils import (
save_checkpoint,
update_last_checkpoint,
)
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import EpisodeAwareSampler, make_dataset
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.rl.wandb_utils import WandBLogger
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
+11 -2
View File
@@ -1,2 +1,11 @@
# gRPC transport layer for async inference.
# Requires: lerobot[grpcio-dep]
"""
gRPC transport layer for async inference.
Requires: ``pip install 'lerobot[grpcio-dep]'``
Available modules (import directly)::
from lerobot.transport.utils import ...
"""
__all__: list[str] = []