chore(processors): tokenizer raises errors and remove tensor conversion (#1949)

This commit is contained in:
Steven Palma
2025-09-16 11:44:02 +02:00
committed by GitHub
parent b12a386334
commit cf7946e602
2 changed files with 5 additions and 9 deletions
+2 -3
View File
@@ -209,7 +209,7 @@ class _NormalizationMixin:
new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse)
return new_observation
def _normalize_action(self, action: Any, inverse: bool) -> Tensor:
def _normalize_action(self, action: Tensor, inverse: bool) -> Tensor:
# Convert to tensor but preserve original dtype for adaptation logic
"""
Applies (un)normalization to an action tensor.
@@ -221,8 +221,7 @@ class _NormalizationMixin:
Returns:
The transformed action tensor.
"""
tensor = torch.as_tensor(action)
processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse)
processed_action = self._apply_transform(action, "action", FeatureType.ACTION, inverse=inverse)
return processed_action
def _apply_transform(
+3 -6
View File
@@ -118,14 +118,11 @@ class TokenizerProcessorStep(ObservationProcessorStep):
"""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return None
if self.task_key not in complementary_data:
return None
raise ValueError("Complementary data is None so no task can be extracted from it")
task = complementary_data[self.task_key]
if task is None:
return None
raise ValueError("Task extracted from Complementary data is None")
# Standardize to a list of strings for the tokenizer
if isinstance(task, str):
@@ -150,7 +147,7 @@ class TokenizerProcessorStep(ObservationProcessorStep):
"""
task = self.get_task(self.transition)
if task is None:
return observation
raise ValueError("Task cannot be None")
# Tokenize the task (this will create CPU tensors)
tokenized_prompt = self._tokenize_text(task)