mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
use constants for indexing into batches and remove env state references
This commit is contained in:
@@ -42,7 +42,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
|
|||||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import populate_queues
|
from lerobot.policies.utils import populate_queues
|
||||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
# -- Policy --
|
# -- Policy --
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||||
assert n_obs_steps == self.config.n_obs_steps
|
assert n_obs_steps == self.config.n_obs_steps
|
||||||
|
|
||||||
conditioning_vec = self.observation_encoder.encode(batch)
|
conditioning_vec = self.observation_encoder.encode(batch)
|
||||||
@@ -120,12 +120,12 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
|||||||
def reset(self):
|
def reset(self):
|
||||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||||
self._queues = {
|
self._queues = {
|
||||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
|
||||||
"action": deque(maxlen=self.config.n_action_steps),
|
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.config.image_features:
|
if self.config.image_features:
|
||||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
|
||||||
|
|
||||||
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
|
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
|
||||||
|
|
||||||
@@ -265,11 +265,6 @@ class ObservationEncoder(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.robot_state_dim = 0
|
self.robot_state_dim = 0
|
||||||
|
|
||||||
if hasattr(config, "env_state_feature") and config.env_state_feature:
|
|
||||||
self.env_state_dim = config.env_state_feature.shape[0]
|
|
||||||
else:
|
|
||||||
self.env_state_dim = 0
|
|
||||||
|
|
||||||
self.text_dim = config.hidden_dim
|
self.text_dim = config.hidden_dim
|
||||||
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
|
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
|
||||||
|
|
||||||
@@ -314,7 +309,6 @@ class ObservationEncoder(nn.Module):
|
|||||||
total_dim += spatial_feature_dim * self.num_cameras
|
total_dim += spatial_feature_dim * self.num_cameras
|
||||||
|
|
||||||
total_dim += self.robot_state_dim
|
total_dim += self.robot_state_dim
|
||||||
total_dim += self.env_state_dim
|
|
||||||
total_dim += self.text_dim
|
total_dim += self.text_dim
|
||||||
|
|
||||||
self.conditioning_dim = total_dim * self.config.n_obs_steps
|
self.conditioning_dim = total_dim * self.config.n_obs_steps
|
||||||
@@ -355,9 +349,6 @@ class ObservationEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
conditioning_feats.append(img_features)
|
conditioning_feats.append(img_features)
|
||||||
|
|
||||||
if self.env_state_dim > 0 and OBS_ENV_STATE in batch:
|
|
||||||
conditioning_feats.append(batch[OBS_ENV_STATE])
|
|
||||||
|
|
||||||
if self.text_encoder is not None and "task" in batch:
|
if self.text_encoder is not None and "task" in batch:
|
||||||
text_features = self.text_encoder(batch["task"])
|
text_features = self.text_encoder(batch["task"])
|
||||||
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
|
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
|
||||||
@@ -664,7 +655,7 @@ class DiffusionObjective(BaseObjective):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
||||||
clean_actions = batch["action"]
|
clean_actions = batch[ACTION]
|
||||||
noise = torch.randn_like(clean_actions)
|
noise = torch.randn_like(clean_actions)
|
||||||
timesteps = torch.randint(
|
timesteps = torch.randint(
|
||||||
low=0,
|
low=0,
|
||||||
@@ -733,7 +724,7 @@ class FlowMatchingObjective(BaseObjective):
|
|||||||
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
|
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
|
||||||
|
|
||||||
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
||||||
data = batch["action"]
|
data = batch[ACTION]
|
||||||
batch_size = data.shape[0]
|
batch_size = data.shape[0]
|
||||||
device = data.device
|
device = data.device
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user