Compare commits

...

8 Commits

Author SHA1 Message Date
Jade Choghari 62d23b0986 add for rest of policies 2026-02-27 16:32:33 +01:00
Jade Choghari a6a2f3662a Merge branch 'main' into speedup-pi05-launch 2026-02-27 18:12:21 +03:00
Khalil Meftah c085531b17 fix: add missing openarm_mini import to CLI scripts (#3028) 2026-02-27 15:46:31 +01:00
Steven Palma c7c6205332 chore(scripts): no spam log when no action (#3042) 2026-02-27 15:26:56 +01:00
Michio Sun 4e54be1334 fix(datasets): skip warning when MultiLeRobotDataset features are identical (#3019)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-26 17:42:22 +01:00
Damien LaRocque fde9d08281 feat(async_inference) Enable plugins with async inference (#2425)
* feat(async-inference) Try using async inference server with plugins

* Fix import

* Fix import error in Robot Client

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-26 14:41:32 +01:00
Khalil Meftah 46044fed75 Fix: remove device_map from SmolVLA model loading (#3029)
* Fix SmolVLA meta tensor error by removing device_map

- Remove device_map parameter from VLM model loading
- Change torch_dtype from string to torch.bfloat16
- Add explicit .to(device) calls after initialization

This resolves NotImplementedError when training SmolVLA policy.
Fixes meta tensor copy issue in factory.py:418.

* fix: remove manual device movement logic and fix dtype handling

---------

Co-authored-by: Highsky7 <albert31115@gmail.com>
2026-02-26 13:28:46 +01:00
Jeremiah Coholich 49444652c6 speedup pi-05 modeling loading by 72s 2026-02-20 15:41:44 -05:00
10 changed files with 49 additions and 24 deletions
+7 -10
View File
@@ -49,23 +49,18 @@ import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.robots import ( # noqa: F401 from lerobot.robots import (
Robot, RobotConfig, # noqa: F401
RobotConfig,
bi_so_follower,
koch_follower,
make_robot_from_config, make_robot_from_config,
omx_follower,
so_follower,
) )
from lerobot.transport import ( from lerobot.transport import (
services_pb2, # type: ignore services_pb2, # type: ignore
services_pb2_grpc, # type: ignore services_pb2_grpc, # type: ignore
) )
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
from lerobot.utils.import_utils import register_third_party_plugins
from .configs import RobotClientConfig from .configs import RobotClientConfig
from .constants import SUPPORTED_ROBOTS
from .helpers import ( from .helpers import (
Action, Action,
FPSTracker, FPSTracker,
@@ -485,8 +480,9 @@ class RobotClient:
def async_client(cfg: RobotClientConfig): def async_client(cfg: RobotClientConfig):
logging.info(pformat(asdict(cfg))) logging.info(pformat(asdict(cfg)))
if cfg.robot.type not in SUPPORTED_ROBOTS: # TODO: Assert if checking robot support is still needed with the plugin system
raise ValueError(f"Robot {cfg.robot.type} not yet supported!") # if cfg.robot.type not in SUPPORTED_ROBOTS:
# raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
client = RobotClient(cfg) client = RobotClient(cfg)
@@ -512,4 +508,5 @@ def async_client(cfg: RobotClientConfig):
if __name__ == "__main__": if __name__ == "__main__":
register_third_party_plugins()
async_client() # run the client async_client() # run the client
+1
View File
@@ -1771,6 +1771,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
) )
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True): for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features) extra_keys = set(ds.features).difference(intersection_features)
if extra_keys:
logging.warning( logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets." "other datasets."
+7
View File
@@ -995,6 +995,13 @@ class PI0Policy(PreTrainedPolicy):
# Initialize model without loading weights # Initialize model without loading weights
# Check if dataset_stats were provided in kwargs # Check if dataset_stats were provided in kwargs
if _transformers_available:
from transformers.modeling_utils import no_init_weights
with no_init_weights():
model = cls(config, **kwargs)
model.model.paligemma_with_expert.paligemma.tie_weights()
else:
model = cls(config, **kwargs) model = cls(config, **kwargs)
# Now manually load and remap the state dict # Now manually load and remap the state dict
@@ -967,6 +967,13 @@ class PI05Policy(PreTrainedPolicy):
# Initialize model without loading weights # Initialize model without loading weights
# Check if dataset_stats were provided in kwargs # Check if dataset_stats were provided in kwargs
if _transformers_available:
from transformers.modeling_utils import no_init_weights
with no_init_weights():
model = cls(config, **kwargs)
model.model.paligemma_with_expert.paligemma.tie_weights()
else:
model = cls(config, **kwargs) model = cls(config, **kwargs)
# Now manually load and remap the state dict # Now manually load and remap the state dict
@@ -895,6 +895,13 @@ class PI0FastPolicy(PreTrainedPolicy):
# Initialize model without loading weights # Initialize model without loading weights
# Check if dataset_stats were provided in kwargs # Check if dataset_stats were provided in kwargs
if _transformers_available:
from transformers.modeling_utils import no_init_weights
with no_init_weights():
model = cls(config, **kwargs)
model.model.paligemma_with_expert.paligemma.tie_weights()
else:
model = cls(config, **kwargs) model = cls(config, **kwargs)
# Now manually load and remap the state dict # Now manually load and remap the state dict
@@ -77,7 +77,6 @@ class SmolVLMWithExpertModel(nn.Module):
print(f"Loading {model_id} weights ...") print(f"Loading {model_id} weights ...")
self.vlm = AutoModelForImageTextToText.from_pretrained( self.vlm = AutoModelForImageTextToText.from_pretrained(
model_id, model_id,
device_map=device,
torch_dtype="bfloat16", torch_dtype="bfloat16",
low_cpu_mem_usage=True, low_cpu_mem_usage=True,
) )
+1
View File
@@ -56,6 +56,7 @@ from lerobot.teleoperators import ( # noqa: F401
make_teleoperator_from_config, make_teleoperator_from_config,
omx_leader, omx_leader,
openarm_leader, openarm_leader,
openarm_mini,
so_leader, so_leader,
unitree_g1, unitree_g1,
) )
@@ -61,6 +61,7 @@ from lerobot.teleoperators import ( # noqa: F401
make_teleoperator_from_config, make_teleoperator_from_config,
omx_leader, omx_leader,
openarm_leader, openarm_leader,
openarm_mini,
so_leader, so_leader,
) )
from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.robot_utils import precise_sleep
+5 -1
View File
@@ -125,6 +125,7 @@ from lerobot.teleoperators import ( # noqa: F401
make_teleoperator_from_config, make_teleoperator_from_config,
omx_leader, omx_leader,
openarm_leader, openarm_leader,
openarm_mini,
reachy2_teleoperator, reachy2_teleoperator,
so_leader, so_leader,
unitree_g1, unitree_g1,
@@ -333,6 +334,7 @@ def record_loop(
preprocessor.reset() preprocessor.reset()
postprocessor.reset() postprocessor.reset()
no_action_count = 0
timestamp = 0 timestamp = 0
start_episode_t = time.perf_counter() start_episode_t = time.perf_counter()
while timestamp < control_time_s: while timestamp < control_time_s:
@@ -380,7 +382,9 @@ def record_loop(
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
act_processed_teleop = teleop_action_processor((act, obs)) act_processed_teleop = teleop_action_processor((act, obs))
else: else:
logging.info( no_action_count += 1
if no_action_count == 1 or no_action_count % 10 == 0:
logging.warning(
"No policy or teleoperator provided, skipping action generation. " "No policy or teleoperator provided, skipping action generation. "
"This is likely to happen when resetting the environment without a teleop device. " "This is likely to happen when resetting the environment without a teleop device. "
"The robot won't be at its rest position at the start of the next episode." "The robot won't be at its rest position at the start of the next episode."
@@ -94,6 +94,7 @@ from lerobot.teleoperators import ( # noqa: F401
make_teleoperator_from_config, make_teleoperator_from_config,
omx_leader, omx_leader,
openarm_leader, openarm_leader,
openarm_mini,
reachy2_teleoperator, reachy2_teleoperator,
so_leader, so_leader,
unitree_g1, unitree_g1,