diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py index dbe4910b8..757da7c17 100644 --- a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py @@ -42,7 +42,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import populate_queues -from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE # -- Policy -- @@ -107,7 +107,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy): ] def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor: - batch_size, n_obs_steps = batch["observation.state"].shape[:2] + batch_size, n_obs_steps = batch[OBS_STATE].shape[:2] assert n_obs_steps == self.config.n_obs_steps conditioning_vec = self.observation_encoder.encode(batch) @@ -120,12 +120,12 @@ class MultiTaskDiTPolicy(PreTrainedPolicy): def reset(self): """Clear observation and action queues. Should be called on `env.reset()`""" self._queues = { - "observation.state": deque(maxlen=self.config.n_obs_steps), - "action": deque(maxlen=self.config.n_action_steps), + OBS_STATE: deque(maxlen=self.config.n_obs_steps), + ACTION: deque(maxlen=self.config.n_action_steps), } if self.config.image_features: - self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps) + self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps) self._queues["task"] = deque(maxlen=self.config.n_obs_steps) @@ -265,11 +265,6 @@ class ObservationEncoder(nn.Module): else: self.robot_state_dim = 0 - if hasattr(config, "env_state_feature") and config.env_state_feature: - self.env_state_dim = config.env_state_feature.shape[0] - else: - self.env_state_dim = 0 - self.text_dim = config.hidden_dim self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim) @@ -314,7 +309,6 @@ class ObservationEncoder(nn.Module): total_dim += spatial_feature_dim * self.num_cameras total_dim += self.robot_state_dim - total_dim += self.env_state_dim total_dim += self.text_dim self.conditioning_dim = total_dim * self.config.n_obs_steps @@ -355,9 +349,6 @@ class ObservationEncoder(nn.Module): ) conditioning_feats.append(img_features) - if self.env_state_dim > 0 and OBS_ENV_STATE in batch: - conditioning_feats.append(batch[OBS_ENV_STATE]) - if self.text_encoder is not None and "task" in batch: text_features = self.text_encoder(batch["task"]) text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1) @@ -664,7 +655,7 @@ class DiffusionObjective(BaseObjective): ) def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor: - clean_actions = batch["action"] + clean_actions = batch[ACTION] noise = torch.randn_like(clean_actions) timesteps = torch.randint( low=0, @@ -733,7 +724,7 @@ class FlowMatchingObjective(BaseObjective): raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}") def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor: - data = batch["action"] + data = batch[ACTION] batch_size = data.shape[0] device = data.device