use constants for indexing into batches and remove env state references

This commit is contained in:
Bryson Jones
2025-12-11 09:13:38 -08:00
parent 9b47c5fac9
commit c398a146b3
@@ -42,7 +42,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
# -- Policy --
@@ -107,7 +107,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
]
def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
conditioning_vec = self.observation_encoder.encode(batch)
@@ -120,12 +120,12 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
ACTION: deque(maxlen=self.config.n_action_steps),
}
if self.config.image_features:
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
@@ -265,11 +265,6 @@ class ObservationEncoder(nn.Module):
else:
self.robot_state_dim = 0
if hasattr(config, "env_state_feature") and config.env_state_feature:
self.env_state_dim = config.env_state_feature.shape[0]
else:
self.env_state_dim = 0
self.text_dim = config.hidden_dim
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
@@ -314,7 +309,6 @@ class ObservationEncoder(nn.Module):
total_dim += spatial_feature_dim * self.num_cameras
total_dim += self.robot_state_dim
total_dim += self.env_state_dim
total_dim += self.text_dim
self.conditioning_dim = total_dim * self.config.n_obs_steps
@@ -355,9 +349,6 @@ class ObservationEncoder(nn.Module):
)
conditioning_feats.append(img_features)
if self.env_state_dim > 0 and OBS_ENV_STATE in batch:
conditioning_feats.append(batch[OBS_ENV_STATE])
if self.text_encoder is not None and "task" in batch:
text_features = self.text_encoder(batch["task"])
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
@@ -664,7 +655,7 @@ class DiffusionObjective(BaseObjective):
)
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
clean_actions = batch["action"]
clean_actions = batch[ACTION]
noise = torch.randn_like(clean_actions)
timesteps = torch.randint(
low=0,
@@ -733,7 +724,7 @@ class FlowMatchingObjective(BaseObjective):
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
data = batch["action"]
data = batch[ACTION]
batch_size = data.shape[0]
device = data.device