mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 06:29:47 +00:00
more fixes
This commit is contained in:
@@ -6,4 +6,3 @@ lerobot-eval \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--seed=142
|
||||
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
import json_numpy
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
"2toINF/X-VLA-WidowX",
|
||||
trust_remote_code=True
|
||||
)
|
||||
model = AutoModel.from_pretrained("2toINF/X-VLA-WidowX", trust_remote_code=True)
|
||||
|
||||
processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", trust_remote_code=True)
|
||||
|
||||
|
||||
# append 3 random image to a list
|
||||
def make_random_pil_images(num_images=3, H=480, W=640):
|
||||
images = []
|
||||
@@ -20,6 +18,7 @@ def make_random_pil_images(num_images=3, H=480, W=640):
|
||||
images.append(img)
|
||||
return images
|
||||
|
||||
|
||||
# Example:
|
||||
images = make_random_pil_images()
|
||||
language_instruction = "This is a random image"
|
||||
@@ -29,23 +28,27 @@ if not {"input_ids", "image_input", "image_mask"}.issubset(inputs):
|
||||
raise ValueError("Processor did not return the expected keys.")
|
||||
|
||||
proprio = torch.randn(1, 20)
|
||||
domain_id = torch.tensor([int(0)], dtype=torch.long)
|
||||
domain_id = torch.tensor([0], dtype=torch.long)
|
||||
|
||||
# Align to model's device/dtype
|
||||
device = model.device
|
||||
dtype = next(model.parameters()).dtype
|
||||
|
||||
|
||||
def to_model(t: torch.Tensor) -> torch.Tensor:
|
||||
if not isinstance(t, torch.Tensor):
|
||||
t = torch.as_tensor(t)
|
||||
# cast floats to model dtype, keep integral/bool as-is
|
||||
return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device)
|
||||
|
||||
|
||||
inputs = {k: to_model(v) for k, v in inputs.items()}
|
||||
inputs.update({
|
||||
inputs.update(
|
||||
{
|
||||
"proprio": to_model(proprio),
|
||||
"domain_id": domain_id.to(device),
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
# Inference
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.xvla.utils import Rotate6D_to_AxisAngle
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
observation_height: int = 360
|
||||
observation_width: int = 360
|
||||
@@ -16,15 +17,19 @@ OBS = {
|
||||
OBS_STATE: torch.randn(1, 9), # ONLY if OBS_STATE is already a string
|
||||
"task": "put the object in the box",
|
||||
}
|
||||
|
||||
|
||||
def fake_rgb(H, W):
|
||||
img = torch.randint(0, 255, (H, W, 3), dtype=torch.uint8).numpy()
|
||||
return img
|
||||
|
||||
|
||||
OBS[f"{OBS_IMAGES}.image1"] = fake_rgb(observation_height, observation_width)
|
||||
OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width)
|
||||
|
||||
# observation = preprocessor(OBS)
|
||||
from transformers import AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained("2toINF/X-VLA-WidowX", num_views=2, trust_remote_code=True)
|
||||
inputs = processor([OBS[f"{OBS_IMAGES}.image1"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
|
||||
breakpoint()
|
||||
@@ -53,6 +58,6 @@ observation = preprocessor(OBS)
|
||||
action = policy.select_action(observation)
|
||||
|
||||
target_eef = action[:, :3].to("cpu").numpy()
|
||||
target_axis = Rotate6D_to_AxisAngle(action[:, 3:9].to("cpu").numpy())
|
||||
target_axis = rotate6d_to_axis_angle(action[:, 3:9].to("cpu").numpy())
|
||||
target_act = action[:, 9:10].to("cpu").numpy()
|
||||
final_action = np.concatenate([target_eef, target_axis, target_act], axis=-1)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import os
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.factory import make_policy, make_policy_config
|
||||
import os
|
||||
|
||||
cfg = make_policy_config("xvla")
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
@@ -16,7 +18,9 @@ for name, param in policy.state_dict().items():
|
||||
import safetensors.torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
cache_dir = snapshot_download(repo_id="2toINF/X-VLA-Libero", repo_type="model", cache_dir="/fsx/jade_choghari/.cache/huggingface/model")
|
||||
cache_dir = snapshot_download(
|
||||
repo_id="2toINF/X-VLA-Libero", repo_type="model", cache_dir="/fsx/jade_choghari/.cache/huggingface/model"
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(os.path.join(cache_dir, "model.safetensors"))
|
||||
# policy.load_state_dict(state_dict)
|
||||
# 3. Add "model." prefix to every key
|
||||
@@ -36,16 +40,18 @@ print()
|
||||
print("unexpected keys:", unexpected)
|
||||
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from xvla.models.modeling_xvla import XVLA
|
||||
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
# from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
from xvla.models.modeling_xvla import XVLA
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
torch.manual_seed(42)
|
||||
random.seed(42)
|
||||
np.random.seed(42)
|
||||
@@ -61,6 +67,8 @@ OBS = {
|
||||
|
||||
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
||||
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
||||
|
||||
|
||||
def fake_rgb(H, W):
|
||||
arr = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)
|
||||
t = torch.from_numpy(arr).permute(2, 0, 1) # CHW
|
||||
@@ -101,11 +109,13 @@ from xvla.models.processing_xvla import XVLAProcessor
|
||||
|
||||
processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero", num_views=2)
|
||||
inputs_1 = processor([OBS[f"{OBS_IMAGES}.image"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
|
||||
domain_id = torch.tensor([int(3)], dtype=torch.long)
|
||||
inputs.update({
|
||||
domain_id = torch.tensor([3], dtype=torch.long)
|
||||
inputs.update(
|
||||
{
|
||||
"proprio": OBS[OBS_STATE].to("cuda"),
|
||||
"domain_id": domain_id.to("cuda"),
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
for k in inputs.keys() & inputs_1.keys(): # intersection of keys
|
||||
@@ -145,16 +155,19 @@ print("max diff:", np.max(np.abs(action - action_1)))
|
||||
print("mean diff:", np.mean(np.abs(action - action_1)))
|
||||
|
||||
|
||||
from xvla.models.processor_xvla import XVLAProcessor
|
||||
from xvla.models.modeling_xvla import XVLA
|
||||
from xvla.models.configuration_xvla import XVLAConfig
|
||||
import torch
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from lerobot.policies.factory import make_policy
|
||||
from xvla.models.configuration_xvla import XVLAConfig
|
||||
from xvla.models.modeling_xvla import XVLA
|
||||
from xvla.models.processor_xvla import XVLAProcessor
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy
|
||||
|
||||
cfg = XVLAConfig.from_pretrained("/raid/jade/models/xvla-libero")
|
||||
model = XVLA.from_pretrained("/raid/jade/models/xvla-libero")
|
||||
model.eval()
|
||||
@@ -166,6 +179,7 @@ torch.manual_seed(42)
|
||||
random.seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
|
||||
def make_random_pil_images(num_images=3, H=480, W=640):
|
||||
images = []
|
||||
for _ in range(num_images):
|
||||
@@ -175,6 +189,7 @@ def make_random_pil_images(num_images=3, H=480, W=640):
|
||||
images.append(img)
|
||||
return images
|
||||
|
||||
|
||||
# Example:
|
||||
images = make_random_pil_images()
|
||||
language_instruction = "This is a random image"
|
||||
@@ -184,23 +199,27 @@ if not {"input_ids", "image_input", "image_mask"}.issubset(inputs):
|
||||
raise ValueError("Processor did not return the expected keys.")
|
||||
|
||||
proprio = torch.randn(1, 20)
|
||||
domain_id = torch.tensor([int(0)], dtype=torch.long)
|
||||
domain_id = torch.tensor([0], dtype=torch.long)
|
||||
|
||||
# Align to model's device/dtype
|
||||
device = model.device
|
||||
dtype = next(model.parameters()).dtype
|
||||
|
||||
|
||||
def to_model(t: torch.Tensor) -> torch.Tensor:
|
||||
if not isinstance(t, torch.Tensor):
|
||||
t = torch.as_tensor(t)
|
||||
# cast floats to model dtype, keep integral/bool as-is
|
||||
return t.to(device=device, dtype=dtype) if t.is_floating_point() else t.to(device=device)
|
||||
|
||||
|
||||
inputs = {k: to_model(v) for k, v in inputs.items()}
|
||||
inputs.update({
|
||||
inputs.update(
|
||||
{
|
||||
"proprio": to_model(proprio),
|
||||
"domain_id": domain_id.to(device),
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
# Inference
|
||||
action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
|
||||
|
||||
@@ -29,7 +29,9 @@ from gymnasium import spaces
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
from robosuite.utils.transform_utils import quat2axisangle
|
||||
from lerobot.policies.xvla.utils import Mat_to_Rotate6D
|
||||
|
||||
from lerobot.policies.xvla.utils import mat_to_rotate6d
|
||||
|
||||
|
||||
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||
"""Normalize camera_name into a non-empty list of strings."""
|
||||
@@ -227,13 +229,15 @@ class LiberoEnv(gym.Env):
|
||||
# TODO: jadechoghari, this is an ugly quick workaround for XVLA states.
|
||||
# we will open a new PR to handle this in a preprocessor.
|
||||
elif self.action_type == "abs":
|
||||
robo_ori = Mat_to_Rotate6D(self._env.robots[0].controller.ee_ori_mat)
|
||||
robo_ori = mat_to_rotate6d(self._env.robots[0].controller.ee_ori_mat)
|
||||
robo_pos = self._env.robots[0].controller.ee_pos
|
||||
proprio = np.concatenate([robo_pos, robo_ori, np.array([0.0])], axis=-1)
|
||||
state = np.concatenate([proprio, np.zeros_like(proprio)], axis=-1)
|
||||
else:
|
||||
raise NotImplementedError(f"The action type '{self.action_type}' is not supported in LiberoEnv. "
|
||||
"Please switch to an action type (e.g. 'rel', 'abs').")
|
||||
raise NotImplementedError(
|
||||
f"The action type '{self.action_type}' is not supported in LiberoEnv. "
|
||||
"Please switch to an action type (e.g. 'rel', 'abs')."
|
||||
)
|
||||
agent_pos = state
|
||||
if self.obs_type == "pixels":
|
||||
return {"pixels": images.copy()}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from lerobot.policies.xvla.processor_xvla import (
|
||||
make_xvla_pre_post_processors,
|
||||
XVLAImageScaleProcessorStep,
|
||||
XVLAAddDomainIdProcessorStep,
|
||||
XVLAImageScaleProcessorStep,
|
||||
XVLARotation6DToAxisAngleProcessorStep,
|
||||
make_xvla_pre_post_processors,
|
||||
)
|
||||
@@ -187,7 +187,7 @@ class Florence2LanguageConfig(PretrainedConfig):
|
||||
>>> configuration = Florence2LanguageConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights)
|
||||
>>> model = Florence2LangaugeModel(configuration)
|
||||
>>> model = Florence2LanguageModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
|
||||
@@ -28,8 +28,6 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
from .configuration_florence2 import Florence2Config
|
||||
from .configuration_florence2 import Florence2VisionConfig
|
||||
from .configuration_florence2 import Florence2LanguageConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("xvla")
|
||||
|
||||
@@ -496,8 +496,8 @@ class DaViT(nn.Module):
|
||||
drop_path_rate (float): Stochastic depth rate. Default: 0.1.
|
||||
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
||||
enable_checkpoint (bool): If True, enable checkpointing. Default: False.
|
||||
conv_at_attn (bool): If True, performe depthwise convolution before attention layer. Default: True.
|
||||
conv_at_ffn (bool): If True, performe depthwise convolution before ffn layer. Default: True.
|
||||
conv_at_attn (bool): If True, perform depthwise convolution before attention layer. Default: True.
|
||||
conv_at_ffn (bool): If True, perform depthwise convolution before ffn layer. Default: True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -892,7 +892,7 @@ class Florence2FlashAttention2(Florence2Attention):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
@@ -34,7 +35,6 @@ from .configuration_xvla import XVLAConfig
|
||||
from .modeling_florence2 import Florence2ForConditionalGeneration
|
||||
from .transformer import SoftPromptedTransformer
|
||||
|
||||
import os
|
||||
|
||||
class XVLAModel(nn.Module):
|
||||
"""
|
||||
@@ -98,7 +98,6 @@ class XVLAModel(nn.Module):
|
||||
# target_size = (224, 224)
|
||||
# flat_images = F.interpolate(flat_images, size=target_size, mode="bilinear", align_corners=False)
|
||||
|
||||
|
||||
num_valid = int(flat_mask.sum().item())
|
||||
if num_valid == 0:
|
||||
raise ValueError("At least one image view must be valid per batch.")
|
||||
@@ -347,7 +346,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
cls,
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: "PreTrainedConfig" | None = None,
|
||||
config: PreTrainedConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
@@ -364,6 +363,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
- skip list for layers that should remain randomly initialized
|
||||
"""
|
||||
import safetensors.torch
|
||||
|
||||
# --- Step 1: Load config ---
|
||||
if config is None:
|
||||
config = PreTrainedConfig.from_pretrained(
|
||||
@@ -386,6 +386,9 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
model_file = os.path.join(model_id, "model.safetensors")
|
||||
else:
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import HfHubHTTPError
|
||||
|
||||
model_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename="model.safetensors",
|
||||
@@ -398,9 +401,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"model.safetensors not found on the Hub at {model_id}"
|
||||
) from e
|
||||
raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
|
||||
|
||||
print(f"Loading checkpoint from {model_file}")
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
|
||||
@@ -21,7 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.policies.xvla.utils import Rotate6D_to_AxisAngle
|
||||
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
@@ -38,6 +38,7 @@ from lerobot.processor.converters import policy_action_to_transition, transition
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_xvla_pre_post_processors(
|
||||
config: XVLAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
@@ -231,7 +232,7 @@ class XVLARotation6DToAxisAngleProcessorStep(ProcessorStep):
|
||||
target_act = action_np[:, 9:10] # (B, 1)
|
||||
|
||||
# Convert 6D rotation to axis-angle
|
||||
target_axis = Rotate6D_to_AxisAngle(rotation_6d) # (B, 3)
|
||||
target_axis = rotate6d_to_axis_angle(rotation_6d) # (B, 3)
|
||||
|
||||
# Concatenate: [eef (3), axis_angle (3), gripper (1)] = 7D
|
||||
action_np = np.concatenate([target_eef, target_axis, target_act], axis=-1)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import numpy as np
|
||||
import robosuite.utils.transform_utils as T
|
||||
def Rotate6D_to_AxisAngle(r6d):
|
||||
|
||||
|
||||
def rotate6d_to_axis_angle(r6d):
|
||||
"""
|
||||
r6d: np.ndarray, shape (N, 6)
|
||||
return: np.ndarray, shape (N, 3), axis-angle vectors
|
||||
@@ -39,7 +41,8 @@ def Rotate6D_to_AxisAngle(r6d):
|
||||
|
||||
return axis_angle_array
|
||||
|
||||
def Mat_to_Rotate6D(abs_action):
|
||||
|
||||
def mat_to_rotate6d(abs_action):
|
||||
if len(abs_action.shape) == 2:
|
||||
return np.concatenate([abs_action[:3, 0], abs_action[:3, 1]], axis=-1)
|
||||
elif len(abs_action.shape) == 3:
|
||||
|
||||
@@ -325,7 +325,9 @@ def load_state_dict_with_missing_key_handling(
|
||||
Returns:
|
||||
List of problematic missing keys that weren't in the whitelist.
|
||||
"""
|
||||
state_dict['model.vlm.language_model.model.encoder.embed_tokens.weight'] = state_dict['model.vlm.language_model.model.shared.weight'].clone()
|
||||
state_dict["model.vlm.language_model.model.encoder.embed_tokens.weight"] = state_dict[
|
||||
"model.vlm.language_model.model.shared.weight"
|
||||
].clone()
|
||||
# Load the cleaned state dict with strict=False to capture missing/unexpected keys
|
||||
load_result = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
|
||||
@@ -25,9 +25,10 @@ import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.datasets.factory import IMAGENET_STATS
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.datasets.factory import IMAGENET_STATS
|
||||
|
||||
from .converters import from_tensor_to_numpy, to_tensor
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
|
||||
@@ -303,7 +304,11 @@ class _NormalizationMixin:
|
||||
ValueError: If an unsupported normalization mode is encountered.
|
||||
"""
|
||||
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
|
||||
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats and norm_mode != NormalizationMode.IMAGENET:
|
||||
if (
|
||||
norm_mode == NormalizationMode.IDENTITY
|
||||
or key not in self._tensor_stats
|
||||
and norm_mode != NormalizationMode.IMAGENET
|
||||
):
|
||||
return tensor
|
||||
if norm_mode not in (
|
||||
NormalizationMode.MEAN_STD,
|
||||
|
||||
@@ -55,6 +55,7 @@ from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, Transitio
|
||||
TInput = TypeVar("TInput")
|
||||
TOutput = TypeVar("TOutput")
|
||||
|
||||
|
||||
class ProcessorStepRegistry:
|
||||
"""A registry for ProcessorStep classes to allow instantiation from a string name.
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ Note that in both examples, the repo/folder should contain at least `config.json
|
||||
|
||||
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
|
||||
"""
|
||||
|
||||
import concurrent.futures as cf
|
||||
import json
|
||||
import logging
|
||||
@@ -89,6 +90,7 @@ from lerobot.utils.utils import (
|
||||
inside_slurm,
|
||||
)
|
||||
|
||||
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
@@ -153,7 +155,6 @@ def rollout(
|
||||
disable=inside_slurm(), # we dont want progress bar when we use slurm, since it clutters the logs
|
||||
leave=False,
|
||||
)
|
||||
|
||||
check_env_attributes_and_types(env)
|
||||
while not np.all(done) and step < max_steps:
|
||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||
@@ -164,17 +165,13 @@ def rollout(
|
||||
# Infer "task" from attributes of environments.
|
||||
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
||||
observation = add_envs_task(env, observation)
|
||||
|
||||
# Preprocess observation (includes image scaling and domain_id addition)
|
||||
observation = preprocessor(observation)
|
||||
# Policy inference
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# Postprocess action (includes rotation conversion and device transfer to CPU)
|
||||
action = postprocessor(action)
|
||||
# Convert to numpy
|
||||
action_numpy: np.ndarray = action.numpy()
|
||||
|
||||
# Convert to CPU / numpy.
|
||||
action_numpy: np.ndarray = action.to("cpu").numpy()
|
||||
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||
|
||||
# Apply the next action.
|
||||
@@ -500,6 +497,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
envs = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
logging.info("Making policy.")
|
||||
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
|
||||
Reference in New Issue
Block a user