mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
refactor code to perform task tokenization in the processor instead of in the modeling code for multitask dit
This commit is contained in:
@@ -79,6 +79,10 @@ class MultiTaskDiTConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
# Text Encoder (CLIP)
|
# Text Encoder (CLIP)
|
||||||
text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
|
text_encoder_name: str = "openai/clip-vit-base-patch16" # HuggingFace CLIP model
|
||||||
|
tokenizer_max_length: int = 77 # Max length for tokenized text (CLIP default is 77)
|
||||||
|
tokenizer_padding: str = "max_length" # Padding strategy: "max_length" or "longest"
|
||||||
|
tokenizer_padding_side: str = "right" # Padding side: "left" or "right"
|
||||||
|
tokenizer_truncation: bool = True # Whether to truncate sequences longer than max_length
|
||||||
|
|
||||||
# Normalization
|
# Normalization
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
|
|||||||
@@ -36,12 +36,18 @@ import torchvision
|
|||||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
|
from transformers import CLIPTextModel, CLIPVisionModel
|
||||||
|
|
||||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import populate_queues
|
from lerobot.policies.utils import populate_queues
|
||||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
from lerobot.utils.constants import (
|
||||||
|
ACTION,
|
||||||
|
OBS_IMAGES,
|
||||||
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
|
OBS_LANGUAGE_TOKENS,
|
||||||
|
OBS_STATE,
|
||||||
|
)
|
||||||
|
|
||||||
# -- Policy --
|
# -- Policy --
|
||||||
|
|
||||||
@@ -127,36 +133,34 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
|||||||
if self.config.image_features:
|
if self.config.image_features:
|
||||||
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
|
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
|
||||||
|
|
||||||
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations"""
|
"""Predict a chunk of actions given environment observations"""
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
original_batch_keys = set(batch.keys())
|
for k in batch:
|
||||||
new_batch = {}
|
if k in self._queues:
|
||||||
for k in self._queues:
|
batch[k] = torch.stack(list(self._queues[k]), dim=1)
|
||||||
if k in original_batch_keys:
|
|
||||||
if self._queues[k] and isinstance(self._queues[k][-1][0], str):
|
|
||||||
new_batch[k] = self._queues[k][-1]
|
|
||||||
else:
|
|
||||||
queue_values = list(self._queues[k])
|
|
||||||
new_batch[k] = torch.stack(queue_values, dim=1)
|
|
||||||
batch = new_batch
|
|
||||||
|
|
||||||
actions = self._generate_actions(batch)
|
actions = self._generate_actions(batch)
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
|
"""Prepare batch by stacking image features if needed."""
|
||||||
|
if self.config.image_features:
|
||||||
|
batch = dict(batch) # shallow copy to avoid modifying original
|
||||||
|
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Select a single action given environment observations"""
|
"""Select a single action given environment observations"""
|
||||||
if ACTION in batch:
|
if ACTION in batch:
|
||||||
|
batch = dict(batch) # shallow copy to avoid modifying original
|
||||||
batch.pop(ACTION)
|
batch.pop(ACTION)
|
||||||
|
|
||||||
if self.config.image_features:
|
batch = self._prepare_batch(batch)
|
||||||
batch = dict(batch)
|
|
||||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
|
||||||
|
|
||||||
self._queues = populate_queues(self._queues, batch)
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
|
||||||
@@ -169,9 +173,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
||||||
"""Run the batch through the model and compute the loss for training"""
|
"""Run the batch through the model and compute the loss for training"""
|
||||||
if self.config.image_features:
|
batch = self._prepare_batch(batch)
|
||||||
batch = dict(batch)
|
|
||||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
|
||||||
|
|
||||||
conditioning_vec = self.observation_encoder.encode(batch)
|
conditioning_vec = self.observation_encoder.encode(batch)
|
||||||
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
|
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
|
||||||
@@ -204,13 +206,16 @@ class CLIPVisionEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class CLIPTextEncoder(nn.Module):
|
class CLIPTextEncoder(nn.Module):
|
||||||
"""CLIP text encoder with frozen weights and a learnable projection layer."""
|
"""CLIP text encoder with frozen weights and a learnable projection layer.
|
||||||
|
|
||||||
|
Accepts pre-tokenized inputs (input_ids and attention_mask) from the processor pipeline. See the processor
|
||||||
|
pipeline to see how the tokenization is handled.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
|
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.projection_dim = projection_dim
|
self.projection_dim = projection_dim
|
||||||
self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
|
|
||||||
self.text_encoder = CLIPTextModel.from_pretrained(model_name)
|
self.text_encoder = CLIPTextModel.from_pretrained(model_name)
|
||||||
|
|
||||||
for param in self.text_encoder.parameters():
|
for param in self.text_encoder.parameters():
|
||||||
@@ -219,16 +224,15 @@ class CLIPTextEncoder(nn.Module):
|
|||||||
self.text_embed_dim = self.text_encoder.config.hidden_size
|
self.text_embed_dim = self.text_encoder.config.hidden_size
|
||||||
self.projection = nn.Linear(self.text_embed_dim, projection_dim)
|
self.projection = nn.Linear(self.text_embed_dim, projection_dim)
|
||||||
|
|
||||||
def forward(self, text: str | list[str]) -> Tensor:
|
def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
|
||||||
"""Encode text to feature vectors."""
|
"""Encode pre-tokenized text to feature vectors."""
|
||||||
if isinstance(text, str):
|
# Ensure inputs are on the same device as the model
|
||||||
text = [text]
|
device = next(self.parameters()).device
|
||||||
|
input_ids = input_ids.to(device)
|
||||||
text_inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
|
attention_mask = attention_mask.to(device)
|
||||||
text_inputs = {k: v.to(next(self.parameters()).device) for k, v in text_inputs.items()}
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = self.text_encoder(**text_inputs)
|
outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
clip_features = outputs.pooler_output
|
clip_features = outputs.pooler_output
|
||||||
|
|
||||||
return self.projection(clip_features)
|
return self.projection(clip_features)
|
||||||
@@ -349,8 +353,12 @@ class ObservationEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
conditioning_feats.append(img_features)
|
conditioning_feats.append(img_features)
|
||||||
|
|
||||||
if self.text_encoder is not None and "task" in batch:
|
if self.text_encoder is not None and OBS_LANGUAGE_TOKENS in batch:
|
||||||
text_features = self.text_encoder(batch["task"])
|
input_ids = batch[OBS_LANGUAGE_TOKENS] # [batch_size, seq_length]
|
||||||
|
attention_mask = batch[OBS_LANGUAGE_ATTENTION_MASK] # [batch_size, seq_length]
|
||||||
|
|
||||||
|
text_features = self.text_encoder(input_ids, attention_mask)
|
||||||
|
|
||||||
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
|
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
|
||||||
conditioning_feats.append(text_features)
|
conditioning_feats.append(text_features)
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from lerobot.processor import (
|
|||||||
PolicyAction,
|
PolicyAction,
|
||||||
PolicyProcessorPipeline,
|
PolicyProcessorPipeline,
|
||||||
RenameObservationsProcessorStep,
|
RenameObservationsProcessorStep,
|
||||||
|
TokenizerProcessorStep,
|
||||||
UnnormalizerProcessorStep,
|
UnnormalizerProcessorStep,
|
||||||
)
|
)
|
||||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||||
@@ -45,8 +46,9 @@ def make_multi_task_dit_pre_post_processors(
|
|||||||
The pre-processing pipeline prepares the input data for the model by:
|
The pre-processing pipeline prepares the input data for the model by:
|
||||||
1. Renaming features.
|
1. Renaming features.
|
||||||
2. Adding a batch dimension.
|
2. Adding a batch dimension.
|
||||||
3. Moving the data to the specified device.
|
3. Tokenizing the language task description (if present).
|
||||||
4. Normalizing the input and output features based on dataset statistics.
|
4. Moving the data to the specified device.
|
||||||
|
5. Normalizing the input and output features based on dataset statistics.
|
||||||
|
|
||||||
The post-processing pipeline handles the model's output by:
|
The post-processing pipeline handles the model's output by:
|
||||||
1. Unnormalizing the output features to their original scale.
|
1. Unnormalizing the output features to their original scale.
|
||||||
@@ -65,6 +67,13 @@ def make_multi_task_dit_pre_post_processors(
|
|||||||
input_steps = [
|
input_steps = [
|
||||||
RenameObservationsProcessorStep(rename_map={}),
|
RenameObservationsProcessorStep(rename_map={}),
|
||||||
AddBatchDimensionProcessorStep(),
|
AddBatchDimensionProcessorStep(),
|
||||||
|
TokenizerProcessorStep(
|
||||||
|
tokenizer_name=config.text_encoder_name,
|
||||||
|
padding=config.tokenizer_padding,
|
||||||
|
padding_side=config.tokenizer_padding_side,
|
||||||
|
max_length=config.tokenizer_max_length,
|
||||||
|
truncation=config.tokenizer_truncation,
|
||||||
|
),
|
||||||
DeviceProcessorStep(device=config.device),
|
DeviceProcessorStep(device=config.device),
|
||||||
NormalizerProcessorStep(
|
NormalizerProcessorStep(
|
||||||
features={**config.input_features, **config.output_features},
|
features={**config.input_features, **config.output_features},
|
||||||
|
|||||||
@@ -30,7 +30,13 @@ from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiT
|
|||||||
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
|
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
|
||||||
make_multi_task_dit_pre_post_processors,
|
make_multi_task_dit_pre_post_processors,
|
||||||
)
|
)
|
||||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
from lerobot.utils.constants import (
|
||||||
|
ACTION,
|
||||||
|
OBS_IMAGES,
|
||||||
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
|
OBS_LANGUAGE_TOKENS,
|
||||||
|
OBS_STATE,
|
||||||
|
)
|
||||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||||
|
|
||||||
|
|
||||||
@@ -132,6 +138,14 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
|
|||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
|
# Use preprocessor to handle tokenization
|
||||||
|
config.normalization_mapping = {
|
||||||
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
|
"STATE": NormalizationMode.IDENTITY,
|
||||||
|
"ACTION": NormalizationMode.IDENTITY,
|
||||||
|
}
|
||||||
|
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||||
|
|
||||||
batch = create_train_batch(
|
batch = create_train_batch(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_obs_steps=n_obs_steps,
|
n_obs_steps=n_obs_steps,
|
||||||
@@ -140,8 +154,11 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
|
|||||||
action_dim=action_dim,
|
action_dim=action_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Process batch through preprocessor to tokenize task text
|
||||||
|
processed_batch = preprocessor(batch)
|
||||||
|
|
||||||
# Test forward pass
|
# Test forward pass
|
||||||
loss, _ = policy.forward(batch)
|
loss, _ = policy.forward(processed_batch)
|
||||||
assert loss is not None
|
assert loss is not None
|
||||||
assert loss.item() is not None
|
assert loss.item() is not None
|
||||||
assert loss.shape == ()
|
assert loss.shape == ()
|
||||||
@@ -204,7 +221,11 @@ def test_multi_task_dit_pre_post_processors():
|
|||||||
assert processed_batch["observation.state"].shape == (1, state_dim)
|
assert processed_batch["observation.state"].shape == (1, state_dim)
|
||||||
assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224)
|
assert processed_batch[f"{OBS_IMAGES}.laptop"].shape == (1, 3, 224, 224)
|
||||||
assert processed_batch[ACTION].shape == (1, action_dim)
|
assert processed_batch[ACTION].shape == (1, action_dim)
|
||||||
assert "task" in processed_batch
|
# Check that task text was tokenized
|
||||||
|
assert OBS_LANGUAGE_TOKENS in processed_batch
|
||||||
|
assert OBS_LANGUAGE_ATTENTION_MASK in processed_batch
|
||||||
|
assert processed_batch[OBS_LANGUAGE_TOKENS].shape[0] == 1 # batch dimension
|
||||||
|
assert processed_batch[OBS_LANGUAGE_ATTENTION_MASK].shape[0] == 1 # batch dimension
|
||||||
|
|
||||||
# Check that data is on correct device
|
# Check that data is on correct device
|
||||||
assert processed_batch["observation.state"].device.type == "cpu"
|
assert processed_batch["observation.state"].device.type == "cpu"
|
||||||
@@ -360,6 +381,14 @@ def test_multi_task_dit_policy_diffusion_objective():
|
|||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
|
# Use preprocessor to handle tokenization
|
||||||
|
config.normalization_mapping = {
|
||||||
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
|
"STATE": NormalizationMode.IDENTITY,
|
||||||
|
"ACTION": NormalizationMode.IDENTITY,
|
||||||
|
}
|
||||||
|
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||||
|
|
||||||
batch = create_train_batch(
|
batch = create_train_batch(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_obs_steps=n_obs_steps,
|
n_obs_steps=n_obs_steps,
|
||||||
@@ -368,8 +397,11 @@ def test_multi_task_dit_policy_diffusion_objective():
|
|||||||
action_dim=action_dim,
|
action_dim=action_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Process batch through preprocessor to tokenize task text
|
||||||
|
processed_batch = preprocessor(batch)
|
||||||
|
|
||||||
# Test forward pass
|
# Test forward pass
|
||||||
loss, _ = policy.forward(batch)
|
loss, _ = policy.forward(processed_batch)
|
||||||
assert loss is not None
|
assert loss is not None
|
||||||
assert loss.item() is not None
|
assert loss.item() is not None
|
||||||
|
|
||||||
@@ -427,6 +459,14 @@ def test_multi_task_dit_policy_flow_matching_objective():
|
|||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
|
# Use preprocessor to handle tokenization
|
||||||
|
config.normalization_mapping = {
|
||||||
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
|
"STATE": NormalizationMode.IDENTITY,
|
||||||
|
"ACTION": NormalizationMode.IDENTITY,
|
||||||
|
}
|
||||||
|
preprocessor, _ = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||||
|
|
||||||
batch = create_train_batch(
|
batch = create_train_batch(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_obs_steps=n_obs_steps,
|
n_obs_steps=n_obs_steps,
|
||||||
@@ -435,8 +475,11 @@ def test_multi_task_dit_policy_flow_matching_objective():
|
|||||||
action_dim=action_dim,
|
action_dim=action_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Process batch through preprocessor to tokenize task text
|
||||||
|
processed_batch = preprocessor(batch)
|
||||||
|
|
||||||
# Test forward pass
|
# Test forward pass
|
||||||
loss, _ = policy.forward(batch)
|
loss, _ = policy.forward(processed_batch)
|
||||||
assert loss is not None
|
assert loss is not None
|
||||||
assert loss.item() is not None
|
assert loss.item() is not None
|
||||||
|
|
||||||
@@ -499,33 +542,39 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
|
|||||||
action_dim=action_dim,
|
action_dim=action_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Move batch to the same device as the policy
|
# Use preprocessor to handle tokenization
|
||||||
for key in batch:
|
config.normalization_mapping = {
|
||||||
if isinstance(batch[key], torch.Tensor):
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
batch[key] = batch[key].to(device)
|
"STATE": NormalizationMode.IDENTITY,
|
||||||
|
"ACTION": NormalizationMode.IDENTITY,
|
||||||
|
}
|
||||||
|
preprocessor, postprocessor = make_multi_task_dit_pre_post_processors(config=config, dataset_stats=None)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with seeded_context(12):
|
with seeded_context(12):
|
||||||
|
# Process batch through preprocessor
|
||||||
|
processed_batch = preprocessor(batch)
|
||||||
|
# Move batch to the same device as the policy
|
||||||
|
for key in processed_batch:
|
||||||
|
if isinstance(processed_batch[key], torch.Tensor):
|
||||||
|
processed_batch[key] = processed_batch[key].to(device)
|
||||||
# Collect policy values before saving
|
# Collect policy values before saving
|
||||||
loss, _ = policy.forward(batch)
|
loss, _ = policy.forward(processed_batch)
|
||||||
|
|
||||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||||
# Move observation batch to device
|
# Process observation through preprocessor
|
||||||
for key in observation_batch:
|
processed_obs = preprocessor(observation_batch)
|
||||||
if isinstance(observation_batch[key], torch.Tensor):
|
actions = policy.select_action(processed_obs)
|
||||||
observation_batch[key] = observation_batch[key].to(device)
|
|
||||||
actions = policy.select_action(observation_batch)
|
|
||||||
|
|
||||||
with seeded_context(12):
|
with seeded_context(12):
|
||||||
|
# Process batch through preprocessor
|
||||||
|
processed_batch = preprocessor(batch)
|
||||||
# Collect policy values after loading
|
# Collect policy values after loading
|
||||||
loaded_loss, _ = loaded_policy.forward(batch)
|
loaded_loss, _ = loaded_policy.forward(processed_batch)
|
||||||
|
|
||||||
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||||
# Move observation batch to device
|
processed_obs = preprocessor(loaded_observation_batch)
|
||||||
for key in loaded_observation_batch:
|
loaded_actions = loaded_policy.select_action(processed_obs)
|
||||||
if isinstance(loaded_observation_batch[key], torch.Tensor):
|
|
||||||
loaded_observation_batch[key] = loaded_observation_batch[key].to(device)
|
|
||||||
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
|
|
||||||
|
|
||||||
# Compare state dicts
|
# Compare state dicts
|
||||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||||
|
|||||||
Reference in New Issue
Block a user