mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 14:09:47 +00:00
wip: still needs batch logic for act and tdmp
This commit is contained in:
@@ -0,0 +1,54 @@
|
||||
from abc import abstractmethod
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class AbstractPolicy(nn.Module):
|
||||
@abstractmethod
|
||||
def update(self, replay_buffer, step):
|
||||
"""One step of the policy's learning algorithm."""
|
||||
pass
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
@abstractmethod
|
||||
def select_action(self, observation) -> Tensor:
|
||||
"""Select an action (or trajectory of actions) based on an observation during rollout.
|
||||
|
||||
Should return a (batch_size, n_action_steps, *) tensor of actions.
|
||||
"""
|
||||
pass
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""Inference step that makes multi-step policies compatible with their single-step environments.
|
||||
|
||||
WARNING: In general, this should not be overriden.
|
||||
|
||||
Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
|
||||
into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
|
||||
observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
|
||||
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
|
||||
the subclass doesn't have to.
|
||||
|
||||
This method effectively wraps the `select_action` method of the subclass. The following assumptions are made:
|
||||
1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
|
||||
the action trajectory horizon and * is the action dimensions.
|
||||
2. Prior to the `select_action` method being called, theres is an `n_action_steps` instance attribute defined.
|
||||
"""
|
||||
n_action_steps_attr = "n_action_steps"
|
||||
if not hasattr(self, n_action_steps_attr):
|
||||
raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
|
||||
if not hasattr(self, "_action_queue"):
|
||||
self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
|
||||
if len(self._action_queue) == 0:
|
||||
# Each element in the queue has shape (B, *).
|
||||
self._action_queue.extend(self.select_action(*args, **kwargs).transpose(0, 1))
|
||||
|
||||
return self._action_queue.popleft()
|
||||
@@ -2,10 +2,10 @@ import logging
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.act.detr_vae import build
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ def kl_divergence(mu, logvar):
|
||||
return total_kld, dimension_wise_kld, mean_kld
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
@@ -147,7 +147,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
def select_action(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
|
||||
@@ -3,14 +3,14 @@ import time
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
|
||||
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
|
||||
|
||||
class DiffusionPolicy(nn.Module):
|
||||
class DiffusionPolicy(AbstractPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
@@ -44,6 +44,7 @@ class DiffusionPolicy(nn.Module):
|
||||
**cfg_obs_encoder,
|
||||
)
|
||||
|
||||
self.n_action_steps = n_action_steps # needed for the parent class
|
||||
self.diffusion = DiffusionUnetImagePolicy(
|
||||
shape_meta=shape_meta,
|
||||
noise_scheduler=noise_scheduler,
|
||||
@@ -93,21 +94,16 @@ class DiffusionPolicy(nn.Module):
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
def select_action(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack to add bsize=1
|
||||
observation["image"] = observation["image"].unsqueeze(0)
|
||||
observation["state"] = observation["state"].unsqueeze(0)
|
||||
|
||||
obs_dict = {
|
||||
"image": observation["image"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
out = self.diffusion.predict_action(obs_dict)
|
||||
|
||||
action = out["action"].squeeze(0)
|
||||
action = out["action"]
|
||||
return action
|
||||
|
||||
def update(self, replay_buffer, step):
|
||||
|
||||
@@ -9,6 +9,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import lerobot.common.policies.tdmpc.helper as h
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
|
||||
FIRST_FRAME = 0
|
||||
|
||||
@@ -85,7 +86,7 @@ class TOLD(nn.Module):
|
||||
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
|
||||
|
||||
|
||||
class TDMPC(nn.Module):
|
||||
class TDMPC(AbstractPolicy):
|
||||
"""Implementation of TD-MPC learning + inference."""
|
||||
|
||||
def __init__(self, cfg, device):
|
||||
@@ -124,7 +125,7 @@ class TDMPC(nn.Module):
|
||||
self.model_target.load_state_dict(d["model_target"])
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
def select_action(self, observation, step_count):
|
||||
t0 = step_count.item() == 0
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack...
|
||||
|
||||
Reference in New Issue
Block a user