mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
use constants for indexing into batches and remove env state references
This commit is contained in:
@@ -42,7 +42,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
|
||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
# -- Policy --
|
||||
|
||||
@@ -107,7 +107,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
||||
]
|
||||
|
||||
def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
conditioning_vec = self.observation_encoder.encode(batch)
|
||||
@@ -120,12 +120,12 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
||||
def reset(self):
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.n_action_steps),
|
||||
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
if self.config.image_features:
|
||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@@ -265,11 +265,6 @@ class ObservationEncoder(nn.Module):
|
||||
else:
|
||||
self.robot_state_dim = 0
|
||||
|
||||
if hasattr(config, "env_state_feature") and config.env_state_feature:
|
||||
self.env_state_dim = config.env_state_feature.shape[0]
|
||||
else:
|
||||
self.env_state_dim = 0
|
||||
|
||||
self.text_dim = config.hidden_dim
|
||||
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
|
||||
|
||||
@@ -314,7 +309,6 @@ class ObservationEncoder(nn.Module):
|
||||
total_dim += spatial_feature_dim * self.num_cameras
|
||||
|
||||
total_dim += self.robot_state_dim
|
||||
total_dim += self.env_state_dim
|
||||
total_dim += self.text_dim
|
||||
|
||||
self.conditioning_dim = total_dim * self.config.n_obs_steps
|
||||
@@ -355,9 +349,6 @@ class ObservationEncoder(nn.Module):
|
||||
)
|
||||
conditioning_feats.append(img_features)
|
||||
|
||||
if self.env_state_dim > 0 and OBS_ENV_STATE in batch:
|
||||
conditioning_feats.append(batch[OBS_ENV_STATE])
|
||||
|
||||
if self.text_encoder is not None and "task" in batch:
|
||||
text_features = self.text_encoder(batch["task"])
|
||||
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
|
||||
@@ -664,7 +655,7 @@ class DiffusionObjective(BaseObjective):
|
||||
)
|
||||
|
||||
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
||||
clean_actions = batch["action"]
|
||||
clean_actions = batch[ACTION]
|
||||
noise = torch.randn_like(clean_actions)
|
||||
timesteps = torch.randint(
|
||||
low=0,
|
||||
@@ -733,7 +724,7 @@ class FlowMatchingObjective(BaseObjective):
|
||||
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
|
||||
|
||||
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
||||
data = batch["action"]
|
||||
data = batch[ACTION]
|
||||
batch_size = data.shape[0]
|
||||
device = data.device
|
||||
|
||||
|
||||
Reference in New Issue
Block a user