mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
remove redundant asserts
This commit is contained in:
@@ -804,11 +804,6 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
|||||||
batch = dict(batch)
|
batch = dict(batch)
|
||||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||||
|
|
||||||
n_obs_steps = batch["observation.state"].shape[1]
|
|
||||||
horizon = batch["action"].shape[1]
|
|
||||||
assert horizon == self.config.horizon
|
|
||||||
assert n_obs_steps == self.config.n_obs_steps
|
|
||||||
|
|
||||||
conditioning_vec = self.observation_encoder.encode(batch)
|
conditioning_vec = self.observation_encoder.encode(batch)
|
||||||
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
|
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user