mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
loss naming
This commit is contained in:
@@ -1009,7 +1009,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
flow_loss = F.mse_loss(u_t, v_t, reduction="none")
|
flow_loss = F.mse_loss(u_t, v_t, reduction="none")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"flow_loss": flow_loss.mean(),
|
"flow_mse_loss": flow_loss.mean(),
|
||||||
"action_ce_loss": fast_loss.mean(),
|
"action_ce_loss": fast_loss.mean(),
|
||||||
"subtask_ce_loss": subtask_loss,
|
"subtask_ce_loss": subtask_loss,
|
||||||
"loss": flow_loss.mean() + 0.1 * subtask_loss.mean() + 0.05 * fast_loss.mean(), # TODO: jadechoghari: check weights
|
"loss": flow_loss.mean() + 0.1 * subtask_loss.mean() + 0.05 * fast_loss.mean(), # TODO: jadechoghari: check weights
|
||||||
@@ -1501,7 +1501,7 @@ class PI05FullPolicy(PreTrainedPolicy):
|
|||||||
# Prepare detailed loss dictionary for logging
|
# Prepare detailed loss dictionary for logging
|
||||||
detailed_loss_dict = {
|
detailed_loss_dict = {
|
||||||
"loss": loss.item(),
|
"loss": loss.item(),
|
||||||
"flow_loss": loss_dict["flow_loss"].mean().item(),
|
"flow_mse_loss": loss_dict["flow_loss"].mean().item(),
|
||||||
"subtask_ce_loss": loss_dict["subtask_ce_loss"].item(),
|
"subtask_ce_loss": loss_dict["subtask_ce_loss"].item(),
|
||||||
"action_ce_loss": loss_dict["action_ce_loss"].item(),
|
"action_ce_loss": loss_dict["action_ce_loss"].item(),
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user