Compare commits

..

1 Commits

Author SHA1 Message Date
AdilZouitine 7ad93bdbf1 fix caching and dataset stats is optional 2025-04-09 13:20:51 +00:00
9 changed files with 120 additions and 124 deletions
+4 -3
View File
@@ -171,6 +171,7 @@ class VideoRecordConfig:
class WrapperConfig: class WrapperConfig:
"""Configuration for environment wrappers.""" """Configuration for environment wrappers."""
delta_action: float | None = None
joint_masking_action_space: list[bool] | None = None joint_masking_action_space: list[bool] | None = None
@@ -190,6 +191,7 @@ class EnvWrapperConfig:
"""Configuration for environment wrappers.""" """Configuration for environment wrappers."""
display_cameras: bool = False display_cameras: bool = False
delta_action: float = 0.1
use_relative_joint_positions: bool = True use_relative_joint_positions: bool = True
add_joint_velocity_to_observation: bool = False add_joint_velocity_to_observation: bool = False
add_ee_pose_to_observation: bool = False add_ee_pose_to_observation: bool = False
@@ -201,9 +203,8 @@ class EnvWrapperConfig:
joint_masking_action_space: Optional[Any] = None joint_masking_action_space: Optional[Any] = None
ee_action_space_params: Optional[EEActionSpaceConfig] = None ee_action_space_params: Optional[EEActionSpaceConfig] = None
use_gripper: bool = False use_gripper: bool = False
gripper_quantization_threshold: float | None = None gripper_quantization_threshold: float = 0.8
gripper_penalty: float = 0.0 gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
open_gripper_on_reset: bool = False open_gripper_on_reset: bool = False
@@ -120,7 +120,7 @@ class SACConfig(PreTrainedConfig):
} }
) )
dataset_stats: dict[str, dict[str, list[float]]] = field( dataset_stats: dict[str, dict[str, list[float]]] | None = field(
default_factory=lambda: { default_factory=lambda: {
"observation.image": { "observation.image": {
"mean": [0.485, 0.456, 0.406], "mean": [0.485, 0.456, 0.406],
+23 -19
View File
@@ -65,16 +65,21 @@ class SACPolicy(
else: else:
self.normalize_inputs = nn.Identity() self.normalize_inputs = nn.Identity()
output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
# HACK: This is hacky and should be removed if config.dataset_stats is not None:
dataset_stats = dataset_stats or output_normalization_params output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats # HACK: This is hacky and should be removed
) dataset_stats = dataset_stats or output_normalization_params
self.unnormalize_outputs = Unnormalize( self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats config.output_features, config.normalization_mapping, dataset_stats
) )
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
else:
self.normalize_targets = nn.Identity()
self.unnormalize_outputs = nn.Identity()
# NOTE: For images the encoder should be shared between the actor and critic # NOTE: For images the encoder should be shared between the actor and critic
if config.shared_encoder: if config.shared_encoder:
@@ -129,7 +134,7 @@ class SACPolicy(
encoder=encoder_critic, encoder=encoder_critic,
input_dim=encoder_critic.output_dim, input_dim=encoder_critic.output_dim,
output_dim=config.num_discrete_actions, output_dim=config.num_discrete_actions,
softmax_temperature=.15, softmax_temperature=1.0,
**asdict(config.grasp_critic_network_kwargs), **asdict(config.grasp_critic_network_kwargs),
) )
@@ -138,14 +143,14 @@ class SACPolicy(
encoder=encoder_critic, encoder=encoder_critic,
input_dim=encoder_critic.output_dim, input_dim=encoder_critic.output_dim,
output_dim=config.num_discrete_actions, output_dim=config.num_discrete_actions,
softmax_temperature=0.15, softmax_temperature=1.0,
**asdict(config.grasp_critic_network_kwargs), **asdict(config.grasp_critic_network_kwargs),
) )
self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict()) self.grasp_critic_target.load_state_dict(self.grasp_critic.state_dict())
# self.grasp_critic = torch.compile(self.grasp_critic) self.grasp_critic = torch.compile(self.grasp_critic)
# self.grasp_critic_target = torch.compile(self.grasp_critic_target) self.grasp_critic_target = torch.compile(self.grasp_critic_target)
self.actor = Policy( self.actor = Policy(
encoder=encoder_actor, encoder=encoder_actor,
@@ -192,7 +197,7 @@ class SACPolicy(
# We cached the encoder output to avoid recomputing it # We cached the encoder output to avoid recomputing it
observations_features = None observations_features = None
if self.shared_encoder: if self.shared_encoder:
observations_features = self.actor.encoder.get_image_features(batch) observations_features = self.actor.encoder.get_image_features(batch, normalize=True)
actions, _, _ = self.actor(batch, observations_features) actions, _, _ = self.actor(batch, observations_features)
actions = self.unnormalize_outputs({"action": actions})["action"] actions = self.unnormalize_outputs({"action": actions})["action"]
@@ -428,7 +433,6 @@ class SACPolicy(
actions_discrete = torch.round(actions_discrete) actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long() actions_discrete = actions_discrete.long()
gripper_penalties: Tensor | None = None
if complementary_info is not None: if complementary_info is not None:
gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty") gripper_penalties: Tensor | None = complementary_info.get("gripper_penalty")
@@ -569,7 +573,7 @@ class SACObservationEncoder(nn.Module):
feat = [] feat = []
obs_dict = self.input_normalization(obs_dict) obs_dict = self.input_normalization(obs_dict)
if len(self.all_image_keys) > 0 and vision_encoder_cache is None: if len(self.all_image_keys) > 0 and vision_encoder_cache is None:
vision_encoder_cache = self.get_image_features(obs_dict) vision_encoder_cache = self.get_image_features(obs_dict, normalize=False)
if vision_encoder_cache is not None: if vision_encoder_cache is not None:
feat.append(vision_encoder_cache) feat.append(vision_encoder_cache)
@@ -584,8 +588,10 @@ class SACObservationEncoder(nn.Module):
return features return features
def get_image_features(self, batch: dict[str, Tensor]) -> torch.Tensor: def get_image_features(self, batch: dict[str, Tensor], normalize: bool = True) -> torch.Tensor:
# [N*B, C, H, W] # [N*B, C, H, W]
if normalize:
batch = self.input_normalization(batch)
if len(self.all_image_keys) > 0: if len(self.all_image_keys) > 0:
# Batch all images along the batch dimension, then encode them. # Batch all images along the batch dimension, then encode them.
images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0) images_batched = torch.cat([batch[key] for key in self.all_image_keys], dim=0)
@@ -786,7 +792,6 @@ class GraspCritic(nn.Module):
super().__init__() super().__init__()
self.encoder = encoder self.encoder = encoder
self.output_dim = output_dim self.output_dim = output_dim
self.softmax_temperature = softmax_temperature
self.net = MLP( self.net = MLP(
input_dim=input_dim, input_dim=input_dim,
@@ -798,7 +803,6 @@ class GraspCritic(nn.Module):
) )
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim) self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
init_final = 0.05
if init_final is not None: if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final) nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final) nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
@@ -221,6 +221,7 @@ def record_episode(
events=events, events=events,
policy=policy, policy=policy,
fps=fps, fps=fps,
# record_delta_actions=record_delta_actions,
teleoperate=policy is None, teleoperate=policy is None,
single_task=single_task, single_task=single_task,
) )
@@ -266,6 +267,8 @@ def control_loop(
if teleoperate: if teleoperate:
observation, action = robot.teleop_step(record_data=True) observation, action = robot.teleop_step(record_data=True)
# if record_delta_actions:
# action["action"] = action["action"] - current_joint_positions
else: else:
observation = robot.capture_observation() observation = robot.capture_observation()
+20 -10
View File
@@ -250,18 +250,28 @@ def act_with_policy(
logging.info("[ACTOR] Shutting down act_with_policy") logging.info("[ACTOR] Shutting down act_with_policy")
return return
# Time policy inference and check if it meets FPS requirement if interaction_step >= cfg.policy.online_step_before_learning:
with TimerManager( # Time policy inference and check if it meets FPS requirement
elapsed_time_list=list_policy_time, with TimerManager(
label="Policy inference time", elapsed_time_list=list_policy_time,
log=False, label="Policy inference time",
) as timer: # noqa: F841 log=False,
action = policy.select_action(batch=obs) ) as timer: # noqa: F841
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9) action = policy.select_action(batch=obs)
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy()) next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
else:
# TODO (azouitine): Make a custom space for torch tensor
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
action = (
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
)
sum_reward_episode += float(reward) sum_reward_episode += float(reward)
# Increment total steps counter for intervention rate # Increment total steps counter for intervention rate
+8 -2
View File
@@ -79,7 +79,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking) transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)): elif isinstance(val, (int, float, bool)):
transition["complementary_info"][key] = torch.tensor( transition["complementary_info"][key] = torch.tensor(
val, device=device val, device=device, non_blocking=non_blocking
) )
else: else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]") raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
@@ -269,7 +269,7 @@ class ReplayBuffer:
self.complementary_info[key] = torch.empty( self.complementary_info[key] = torch.empty(
(self.capacity, *value_shape), device=self.storage_device (self.capacity, *value_shape), device=self.storage_device
) )
elif isinstance(value, (int, float, bool)): elif isinstance(value, (int, float)):
# Handle scalar values similar to reward # Handle scalar values similar to reward
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device) self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
else: else:
@@ -505,6 +505,7 @@ class ReplayBuffer:
state_keys: Optional[Sequence[str]] = None, state_keys: Optional[Sequence[str]] = None,
capacity: Optional[int] = None, capacity: Optional[int] = None,
action_mask: Optional[Sequence[int]] = None, action_mask: Optional[Sequence[int]] = None,
action_delta: Optional[float] = None,
image_augmentation_function: Optional[Callable] = None, image_augmentation_function: Optional[Callable] = None,
use_drq: bool = True, use_drq: bool = True,
storage_device: str = "cpu", storage_device: str = "cpu",
@@ -519,6 +520,7 @@ class ReplayBuffer:
state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`. state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`.
capacity (Optional[int]): Buffer capacity. If None, uses dataset length. capacity (Optional[int]): Buffer capacity. If None, uses dataset length.
action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep. action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep.
action_delta (Optional[float]): Factor to divide actions by.
image_augmentation_function (Optional[Callable]): Function for image augmentation. image_augmentation_function (Optional[Callable]): Function for image augmentation.
If None, uses default random shift with pad=4. If None, uses default random shift with pad=4.
use_drq (bool): Whether to use DrQ image augmentation when sampling. use_drq (bool): Whether to use DrQ image augmentation when sampling.
@@ -563,6 +565,8 @@ class ReplayBuffer:
else: else:
first_action = first_action[:, action_mask] first_action = first_action[:, action_mask]
if action_delta is not None:
first_action = first_action / action_delta
# Get complementary info if available # Get complementary info if available
first_complementary_info = None first_complementary_info = None
@@ -594,6 +598,8 @@ class ReplayBuffer:
else: else:
action = action[:, action_mask] action = action[:, action_mask]
if action_delta is not None:
action = action / action_delta
replay_buffer.add( replay_buffer.add(
state=data["state"], state=data["state"],
@@ -258,25 +258,25 @@ class GamepadController(InputController):
elif event.button == 0: elif event.button == 0:
self.episode_end_status = "rerecord_episode" self.episode_end_status = "rerecord_episode"
# LT button for closing gripper # RB button (6) for opening gripper
elif event.button == 6: elif event.button == 6:
self.close_gripper_command = True
# RB button for opening gripper
elif event.button == 7:
self.open_gripper_command = True self.open_gripper_command = True
# LT button (7) for closing gripper
elif event.button == 7:
self.close_gripper_command = True
# Reset episode status on button release # Reset episode status on button release
elif event.type == pygame.JOYBUTTONUP: elif event.type == pygame.JOYBUTTONUP:
if event.button in [0, 2, 3]: if event.button in [0, 2, 3]:
self.episode_end_status = None self.episode_end_status = None
if event.button == 6: elif event.button == 6:
self.close_gripper_command = False
if event.button == 7:
self.open_gripper_command = False self.open_gripper_command = False
elif event.button == 7:
self.close_gripper_command = False
# Check for RB button (typically button 5) for intervention flag # Check for RB button (typically button 5) for intervention flag
if self.joystick.get_button(5): if self.joystick.get_button(5):
self.intervention_flag = True self.intervention_flag = True
+44 -71
View File
@@ -42,6 +42,7 @@ class HILSerlRobotEnv(gym.Env):
self, self,
robot, robot,
use_delta_action_space: bool = True, use_delta_action_space: bool = True,
delta: float | None = None,
display_cameras: bool = False, display_cameras: bool = False,
): ):
""" """
@@ -54,6 +55,8 @@ class HILSerlRobotEnv(gym.Env):
robot: The robot interface object used to connect and interact with the physical robot. robot: The robot interface object used to connect and interact with the physical robot.
use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
joint positions are used. joint positions are used.
delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
0 and 1 when using a delta action space.
display_cameras (bool): If True, the robot's camera feeds will be displayed during execution. display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
""" """
super().__init__() super().__init__()
@@ -71,6 +74,7 @@ class HILSerlRobotEnv(gym.Env):
self.current_step = 0 self.current_step = 0
self.episode_data = None self.episode_data = None
self.delta = delta
self.use_delta_action_space = use_delta_action_space self.use_delta_action_space = use_delta_action_space
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position") self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
@@ -370,7 +374,7 @@ class RewardWrapper(gym.Wrapper):
self.device = device self.device = device
def step(self, action): def step(self, action):
observation, reward, terminated, truncated, info = self.env.step(action) observation, _, terminated, truncated, info = self.env.step(action)
images = [ images = [
observation[key].to(self.device, non_blocking=self.device.type == "cuda") observation[key].to(self.device, non_blocking=self.device.type == "cuda")
for key in observation for key in observation
@@ -378,17 +382,15 @@ class RewardWrapper(gym.Wrapper):
] ]
start_time = time.perf_counter() start_time = time.perf_counter()
with torch.inference_mode(): with torch.inference_mode():
success = ( reward = (
self.reward_classifier.predict_reward(images, threshold=0.8) self.reward_classifier.predict_reward(images, threshold=0.8)
if self.reward_classifier is not None if self.reward_classifier is not None
else 0.0 else 0.0
) )
info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time) info["Reward classifer frequency"] = 1 / (time.perf_counter() - start_time)
if success == 1.0: if reward == 1.0:
terminated = True terminated = True
reward = 1.0
return observation, reward, terminated, truncated, info return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None): def reset(self, seed=None, options=None):
@@ -553,9 +555,6 @@ class ImageCropResizeWrapper(gym.Wrapper):
# TODO(michel-aractingi): Bug in resize, it returns values outside [0, 1] # TODO(michel-aractingi): Bug in resize, it returns values outside [0, 1]
obs[k] = obs[k].clamp(0.0, 1.0) obs[k] = obs[k].clamp(0.0, 1.0)
# import cv2
# cv2.imwrite(f"tmp_img/{k}.jpg", cv2.cvtColor(obs[k].squeeze(0).permute(1,2,0).cpu().numpy()*255, cv2.COLOR_RGB2BGR))
# Check for NaNs after processing # Check for NaNs after processing
if torch.isnan(obs[k]).any(): if torch.isnan(obs[k]).any():
logging.error(f"NaN values detected in observation {k} after crop and resize") logging.error(f"NaN values detected in observation {k} after crop and resize")
@@ -721,31 +720,19 @@ class ResetWrapper(gym.Wrapper):
env: HILSerlRobotEnv, env: HILSerlRobotEnv,
reset_pose: np.ndarray | None = None, reset_pose: np.ndarray | None = None,
reset_time_s: float = 5, reset_time_s: float = 5,
open_gripper_on_reset: bool = False
): ):
super().__init__(env) super().__init__(env)
self.reset_time_s = reset_time_s self.reset_time_s = reset_time_s
self.reset_pose = reset_pose self.reset_pose = reset_pose
self.robot = self.unwrapped.robot self.robot = self.unwrapped.robot
self.open_gripper_on_reset = open_gripper_on_reset
def reset(self, *, seed=None, options=None): def reset(self, *, seed=None, options=None):
if self.reset_pose is not None: if self.reset_pose is not None:
start_time = time.perf_counter() start_time = time.perf_counter()
log_say("Reset the environment.", play_sounds=True) log_say("Reset the environment.", play_sounds=True)
reset_follower_position(self.robot, self.reset_pose) reset_follower_position(self.robot, self.reset_pose)
busy_wait(self.reset_time_s - (time.perf_counter() - start_time)) busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
log_say("Reset the environment done.", play_sounds=True) log_say("Reset the environment done.", play_sounds=True)
if self.open_gripper_on_reset:
current_joint_pos = self.robot.follower_arms["main"].read("Present_Position")
current_joint_pos[-1] = MAX_GRIPPER_COMMAND
self.robot.send_action(torch.from_numpy(current_joint_pos))
busy_wait(0.1)
current_joint_pos[-1] = 0.0
self.robot.send_action(torch.from_numpy(current_joint_pos))
busy_wait(0.2)
else: else:
log_say( log_say(
f"Manually reset the environment for {self.reset_time_s} seconds.", f"Manually reset the environment for {self.reset_time_s} seconds.",
@@ -775,48 +762,37 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
class GripperPenaltyWrapper(gym.RewardWrapper): class GripperPenaltyWrapper(gym.RewardWrapper):
def __init__(self, env, penalty: float = -0.1, gripper_penalty_in_reward: bool = True): def __init__(self, env, penalty: float = -0.1):
super().__init__(env) super().__init__(env)
self.penalty = penalty self.penalty = penalty
self.gripper_penalty_in_reward = gripper_penalty_in_reward
self.last_gripper_state = None self.last_gripper_state = None
def reward(self, reward, action): def reward(self, reward, action):
gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND gripper_state_normalized = self.last_gripper_state / MAX_GRIPPER_COMMAND
action_normalized = action - 1.0 #action / MAX_GRIPPER_COMMAND if isinstance(action, tuple):
action = action[0]
action_normalized = action[-1] / MAX_GRIPPER_COMMAND
gripper_penalty_bool = (gripper_state_normalized < 0.75 and action_normalized > 0.5) or ( gripper_penalty_bool = (gripper_state_normalized < 0.1 and action_normalized > 0.9) or (
gripper_state_normalized > 0.75 and action_normalized < -0.5 gripper_state_normalized > 0.9 and action_normalized < 0.1
) )
breakpoint()
return reward + self.penalty * int(gripper_penalty_bool) return reward + self.penalty * gripper_penalty_bool
def step(self, action): def step(self, action):
self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] self.last_gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
if isinstance(action, tuple):
gripper_action = action[0][-1]
else:
gripper_action = action[-1]
obs, reward, terminated, truncated, info = self.env.step(action) obs, reward, terminated, truncated, info = self.env.step(action)
gripper_penalty = self.reward(reward, gripper_action) reward = self.reward(reward, action)
if self.gripper_penalty_in_reward:
reward += gripper_penalty
else:
info["gripper_penalty"] = gripper_penalty
return obs, reward, terminated, truncated, info return obs, reward, terminated, truncated, info
def reset(self, **kwargs): def reset(self, **kwargs):
self.last_gripper_state = None self.last_gripper_state = None
obs, info = super().reset(**kwargs) return super().reset(**kwargs)
if self.gripper_penalty_in_reward:
info["gripper_penalty"] = 0.0
return obs, info
class GripperActionWrapper(gym.ActionWrapper):
class GripperQuantizationWrapper(gym.ActionWrapper):
def __init__(self, env, quantization_threshold: float = 0.2): def __init__(self, env, quantization_threshold: float = 0.2):
super().__init__(env) super().__init__(env)
self.quantization_threshold = quantization_threshold self.quantization_threshold = quantization_threshold
@@ -825,18 +801,16 @@ class GripperActionWrapper(gym.ActionWrapper):
is_intervention = False is_intervention = False
if isinstance(action, tuple): if isinstance(action, tuple):
action, is_intervention = action action, is_intervention = action
gripper_command = action[-1] gripper_command = action[-1]
# Quantize gripper command to -1, 0 or 1
if gripper_command < -self.quantization_threshold:
gripper_command = -MAX_GRIPPER_COMMAND
elif gripper_command > self.quantization_threshold:
gripper_command = MAX_GRIPPER_COMMAND
else:
gripper_command = 0.0
# Gripper actions are between 0, 2
# we want to quantize them to -1, 0 or 1
gripper_command = gripper_command - 1.0
if self.quantization_threshold is not None:
# Quantize gripper command to -1, 0 or 1
gripper_command = (
np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0
)
gripper_command = gripper_command * MAX_GRIPPER_COMMAND
gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1] gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND) gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
action[-1] = gripper_action.item() action[-1] = gripper_action.item()
@@ -862,12 +836,10 @@ class EEActionWrapper(gym.ActionWrapper):
] ]
) )
if self.use_gripper: if self.use_gripper:
# gripper actions open at 2.0, and closed at 0.0 action_space_bounds = np.concatenate([action_space_bounds, [1.0]])
min_action_space_bounds = np.concatenate([-action_space_bounds, [0.0]])
max_action_space_bounds = np.concatenate([action_space_bounds, [2.0]])
ee_action_space = gym.spaces.Box( ee_action_space = gym.spaces.Box(
low=min_action_space_bounds, low=-action_space_bounds,
high=max_action_space_bounds, high=action_space_bounds,
shape=(3 + int(self.use_gripper),), shape=(3 + int(self.use_gripper),),
dtype=np.float32, dtype=np.float32,
) )
@@ -1025,11 +997,11 @@ class GamepadControlWrapper(gym.Wrapper):
if self.use_gripper: if self.use_gripper:
gripper_command = self.controller.gripper_command() gripper_command = self.controller.gripper_command()
if gripper_command == "open": if gripper_command == "open":
gamepad_action = np.concatenate([gamepad_action, [2.0]])
elif gripper_command == "close":
gamepad_action = np.concatenate([gamepad_action, [0.0]])
else:
gamepad_action = np.concatenate([gamepad_action, [1.0]]) gamepad_action = np.concatenate([gamepad_action, [1.0]])
elif gripper_command == "close":
gamepad_action = np.concatenate([gamepad_action, [-1.0]])
else:
gamepad_action = np.concatenate([gamepad_action, [0.0]])
# Check episode ending buttons # Check episode ending buttons
# We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None # We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
@@ -1169,6 +1141,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
env = HILSerlRobotEnv( env = HILSerlRobotEnv(
robot=robot, robot=robot,
display_cameras=cfg.wrapper.display_cameras, display_cameras=cfg.wrapper.display_cameras,
delta=cfg.wrapper.delta_action,
use_delta_action_space=cfg.wrapper.use_relative_joint_positions use_delta_action_space=cfg.wrapper.use_relative_joint_positions
and cfg.wrapper.ee_action_space_params is None, and cfg.wrapper.ee_action_space_params is None,
) )
@@ -1192,11 +1165,10 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
if cfg.wrapper.use_gripper: if cfg.wrapper.use_gripper:
env = GripperActionWrapper( env = GripperQuantizationWrapper(
env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold
) )
if cfg.wrapper.gripper_penalty is not None: # env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty)
env = GripperPenaltyWrapper(env=env, penalty=cfg.wrapper.gripper_penalty, gripper_penalty_in_reward=cfg.wrapper.gripper_penalty_in_reward)
if cfg.wrapper.ee_action_space_params is not None: if cfg.wrapper.ee_action_space_params is not None:
env = EEActionWrapper( env = EEActionWrapper(
@@ -1204,7 +1176,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
ee_action_space_params=cfg.wrapper.ee_action_space_params, ee_action_space_params=cfg.wrapper.ee_action_space_params,
use_gripper=cfg.wrapper.use_gripper, use_gripper=cfg.wrapper.use_gripper,
) )
if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad: if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
# env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params) # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
env = GamepadControlWrapper( env = GamepadControlWrapper(
@@ -1221,7 +1192,6 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
env=env, env=env,
reset_pose=cfg.wrapper.fixed_reset_joint_positions, reset_pose=cfg.wrapper.fixed_reset_joint_positions,
reset_time_s=cfg.wrapper.reset_time_s, reset_time_s=cfg.wrapper.reset_time_s,
open_gripper_on_reset=cfg.wrapper.open_gripper_on_reset
) )
if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None: if cfg.wrapper.ee_action_space_params is None and cfg.wrapper.joint_masking_action_space is not None:
env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space) env = JointMaskingActionSpace(env=env, mask=cfg.wrapper.joint_masking_action_space)
@@ -1371,10 +1341,11 @@ def record_dataset(env, policy, cfg):
dataset.push_to_hub() dataset.push_to_hub()
def replay_episode(env, cfg): def replay_episode(env, repo_id, root=None, episode=0):
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode]) local_files_only = root is not None
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
env.reset() env.reset()
actions = dataset.hf_dataset.select_columns("action") actions = dataset.hf_dataset.select_columns("action")
@@ -1382,7 +1353,7 @@ def replay_episode(env, cfg):
for idx in range(dataset.num_frames): for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter() start_episode_t = time.perf_counter()
action = actions[idx]["action"] action = actions[idx]["action"][:4]
env.step((action, False)) env.step((action, False))
# env.step((action / env.unwrapped.delta, False)) # env.step((action / env.unwrapped.delta, False))
@@ -1413,7 +1384,9 @@ def main(cfg: EnvConfig):
if cfg.mode == "replay": if cfg.mode == "replay":
replay_episode( replay_episode(
env, env,
cfg=cfg, cfg.replay_repo_id,
root=cfg.dataset_root,
episode=cfg.replay_episode,
) )
exit() exit()
+6 -7
View File
@@ -380,7 +380,6 @@ def add_actor_information_and_train(
for _ in range(utd_ratio - 1): for _ in range(utd_ratio - 1):
# Sample from the iterators # Sample from the iterators
batch = next(online_iterator) batch = next(online_iterator)
# batch = replay_buffer.sample(batch_size)
if dataset_repo_id is not None: if dataset_repo_id is not None:
batch_offline = next(offline_iterator) batch_offline = next(offline_iterator)
@@ -408,7 +407,6 @@ def add_actor_information_and_train(
"done": done, "done": done,
"observation_feature": observation_features, "observation_feature": observation_features,
"next_observation_feature": next_observation_features, "next_observation_feature": next_observation_features,
"complementary_info": batch.get("complementary_info", None),
} }
# Use the forward method for critic loss (includes both main critic and grasp critic) # Use the forward method for critic loss (includes both main critic and grasp critic)
@@ -439,11 +437,9 @@ def add_actor_information_and_train(
# Sample for the last update in the UTD ratio # Sample for the last update in the UTD ratio
batch = next(online_iterator) batch = next(online_iterator)
# batch = replay_buffer.sample(batch_size)
if dataset_repo_id is not None: if dataset_repo_id is not None:
batch_offline = next(offline_iterator) batch_offline = next(offline_iterator)
# batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions( batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline left_batch_transitions=batch, right_batch_transition=batch_offline
) )
@@ -779,7 +775,9 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
params=policy.actor.parameters_to_optimize, params=policy.actor.parameters_to_optimize,
lr=cfg.policy.actor_lr, lr=cfg.policy.actor_lr,
) )
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr) optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
)
if cfg.policy.num_discrete_actions is not None: if cfg.policy.num_discrete_actions is not None:
optimizer_grasp_critic = torch.optim.Adam( optimizer_grasp_critic = torch.optim.Adam(
@@ -994,6 +992,7 @@ def initialize_offline_replay_buffer(
device=device, device=device,
state_keys=cfg.policy.input_features.keys(), state_keys=cfg.policy.input_features.keys(),
action_mask=active_action_dims, action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device, storage_device=storage_device,
optimize_memory=True, optimize_memory=True,
capacity=cfg.policy.offline_buffer_capacity, capacity=cfg.policy.offline_buffer_capacity,
@@ -1027,8 +1026,8 @@ def get_observation_features(
return None, None return None, None
with torch.no_grad(): with torch.no_grad():
observation_features = policy.actor.encoder.get_image_features(observations) observation_features = policy.actor.encoder.get_image_features(observations, normalize=True)
next_observation_features = policy.actor.encoder.get_image_features(next_observations) next_observation_features = policy.actor.encoder.get_image_features(next_observations, normalize=True)
return observation_features, next_observation_features return observation_features, next_observation_features