mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
fix(smolvla2): train on rendered language batches
Keep annotated language columns through collation, render batched recipe samples, and make SmolVLA2 text loss robust enough for distributed training on the steerable dataset. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -62,14 +62,14 @@ blend:
|
||||
ask_vqa_top:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.top}
|
||||
- {type: image, feature: observation.images.front}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
|
||||
@@ -34,7 +34,6 @@ from .dataset_tools import (
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from .factory import make_dataset, resolve_delta_timestamps
|
||||
from .image_writer import safe_stop_image_writer
|
||||
from .io_utils import load_episodes, write_stats
|
||||
from .language import (
|
||||
@@ -53,6 +52,19 @@ from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import VideoEncodingManager
|
||||
|
||||
|
||||
def make_dataset(*args, **kwargs):
    """Forward to ``.factory.make_dataset``, importing the factory lazily.

    The import happens on each call instead of at package import time; all
    positional and keyword arguments are passed through unchanged and the
    factory's return value is returned as-is.
    """
    from .factory import make_dataset as _factory_make_dataset

    dataset = _factory_make_dataset(*args, **kwargs)
    return dataset
|
||||
|
||||
|
||||
def resolve_delta_timestamps(*args, **kwargs):
    """Forward to ``.factory.resolve_delta_timestamps``, importing it lazily.

    The import happens on each call instead of at package import time; all
    positional and keyword arguments are passed through unchanged and the
    factory's return value is returned as-is.
    """
    from .factory import resolve_delta_timestamps as _factory_resolve

    resolved = _factory_resolve(*args, **kwargs)
    return resolved
|
||||
|
||||
|
||||
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
|
||||
# and legacy migration constants are intentionally NOT re-exported here.
|
||||
# Import directly: ``from lerobot.datasets.io_utils import ...``
|
||||
|
||||
@@ -64,8 +64,22 @@ def _json_arrow_type() -> pa.DataType:
|
||||
|
||||
|
||||
def _json_feature() -> object:
    """Return the HF feature used for tool-call payloads.

    Older ``datasets`` versions do not expose ``datasets.Json``. The
    annotation pipeline currently emits the canonical ``say`` tool call
    shape, so use that explicit struct instead of falling back to a string
    that cannot cast structured parquet values.

    Returns:
        ``datasets.Json()`` when the installed ``datasets`` exposes it,
        otherwise a nested feature dict describing
        ``{"type", "function": {"name", "arguments": {"text"}}}`` with
        string leaves.
    """
    if hasattr(datasets, "Json"):
        return datasets.Json()
    # Explicit struct fallback: mirrors the single tool-call shape described
    # above so structured values keep their fields instead of being
    # flattened to a string.
    return {
        "type": datasets.Value("string"),
        "function": {
            "name": datasets.Value("string"),
            "arguments": {"text": datasets.Value("string")},
        },
    }
|
||||
|
||||
|
||||
def language_persistent_row_arrow_type() -> pa.StructType:
|
||||
|
||||
@@ -26,6 +26,7 @@ from .sac.configuration_sac import SACConfig as SACConfig
|
||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla2.configuration_smolvla2 import SmolVLA2Config as SmolVLA2Config
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .utils import make_robot_action, prepare_observation_for_inference
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
@@ -49,6 +50,7 @@ __all__ = [
|
||||
"SACConfig",
|
||||
"SARMConfig",
|
||||
"SmolVLAConfig",
|
||||
"SmolVLA2Config",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
"WallXConfig",
|
||||
|
||||
@@ -99,85 +99,67 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
||||
# falls back to whatever ``task`` is in the transition.
|
||||
return transition
|
||||
|
||||
message_streams: list[str | None] = list(comp.get("message_streams") or [])
|
||||
target_indices: list[int] = sorted(
|
||||
int(i) for i in (comp.get("target_message_indices") or [])
|
||||
)
|
||||
|
||||
tokenizer = self._get_tokenizer()
|
||||
text_messages = [_strip_lerobot_blocks(m) for m in messages]
|
||||
|
||||
# Tokenize the full chat once.
|
||||
full_ids = tokenizer.apply_chat_template(
|
||||
text_messages,
|
||||
tools=self.tools,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_tensors=None,
|
||||
if _is_batched_messages(messages):
|
||||
encoded = [
|
||||
self._encode_messages(
|
||||
tokenizer,
|
||||
msg,
|
||||
list(streams),
|
||||
sorted(int(i) for i in indices),
|
||||
)
|
||||
for msg, streams, indices in zip(
|
||||
messages,
|
||||
comp.get("message_streams") or [[] for _ in messages],
|
||||
comp.get("target_message_indices") or [[] for _ in messages],
|
||||
strict=True,
|
||||
)
|
||||
]
|
||||
else:
|
||||
encoded = [
|
||||
self._encode_messages(
|
||||
tokenizer,
|
||||
messages,
|
||||
list(comp.get("message_streams") or []),
|
||||
sorted(int(i) for i in (comp.get("target_message_indices") or [])),
|
||||
)
|
||||
]
|
||||
|
||||
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||
target_length = self.max_length if self.padding == "max_length" else max(
|
||||
len(ids) for ids, _, _ in encoded
|
||||
)
|
||||
if isinstance(full_ids, list) and full_ids and isinstance(full_ids[0], list):
|
||||
full_ids = full_ids[0]
|
||||
target_length = min(target_length, self.max_length)
|
||||
|
||||
# Build the label mask by re-rendering progressively up to each
|
||||
# target message and reading off the prefix length. This is the
|
||||
# robust way to get exact token boundaries: we use the same
|
||||
# tokenizer, the same ``tools=`` argument, and the same chat
|
||||
# template — so the prefix tokens are guaranteed to be a prefix
|
||||
# of the full sequence.
|
||||
labels = [-100] * len(full_ids)
|
||||
for tgt in target_indices:
|
||||
prefix_ids = tokenizer.apply_chat_template(
|
||||
text_messages[:tgt],
|
||||
tools=self.tools,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_tensors=None,
|
||||
)
|
||||
full_through_target = tokenizer.apply_chat_template(
|
||||
text_messages[: tgt + 1],
|
||||
tools=self.tools,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_tensors=None,
|
||||
)
|
||||
if isinstance(prefix_ids, list) and prefix_ids and isinstance(prefix_ids[0], list):
|
||||
prefix_ids = prefix_ids[0]
|
||||
if (
|
||||
isinstance(full_through_target, list)
|
||||
and full_through_target
|
||||
and isinstance(full_through_target[0], list)
|
||||
):
|
||||
full_through_target = full_through_target[0]
|
||||
start = len(prefix_ids)
|
||||
end = min(len(full_through_target), len(full_ids))
|
||||
for pos in range(start, end):
|
||||
labels[pos] = int(full_ids[pos])
|
||||
ids_batch = []
|
||||
attn_batch = []
|
||||
labels_batch = []
|
||||
predict_actions = []
|
||||
for ids, labels, predict_action in encoded:
|
||||
ids = ids[:target_length]
|
||||
labels = labels[:target_length]
|
||||
attn = [1] * len(ids)
|
||||
if len(ids) < target_length:
|
||||
n_pad = target_length - len(ids)
|
||||
ids = ids + [pad_id] * n_pad
|
||||
labels = labels + [-100] * n_pad
|
||||
attn = attn + [0] * n_pad
|
||||
ids_batch.append(ids)
|
||||
attn_batch.append(attn)
|
||||
labels_batch.append(labels)
|
||||
predict_actions.append(predict_action)
|
||||
|
||||
# Truncate / pad to ``max_length`` so batches collate cleanly.
|
||||
# The SmolVLA pipeline downstream relies on a fixed length
|
||||
# behaviour ("longest" or "max_length") — we mirror it here.
|
||||
if len(full_ids) > self.max_length:
|
||||
full_ids = full_ids[: self.max_length]
|
||||
labels = labels[: self.max_length]
|
||||
attn = [1] * len(full_ids)
|
||||
if self.padding == "max_length" and len(full_ids) < self.max_length:
|
||||
pad_id = (
|
||||
tokenizer.pad_token_id
|
||||
if tokenizer.pad_token_id is not None
|
||||
else 0
|
||||
)
|
||||
n_pad = self.max_length - len(full_ids)
|
||||
full_ids = full_ids + [pad_id] * n_pad
|
||||
labels = labels + [-100] * n_pad
|
||||
attn = attn + [0] * n_pad
|
||||
ids_t = torch.tensor(ids_batch, dtype=torch.long)
|
||||
attn_t = torch.tensor(attn_batch, dtype=torch.bool)
|
||||
labels_t = torch.tensor(labels_batch, dtype=torch.long)
|
||||
predict_actions_t = torch.tensor(predict_actions, dtype=torch.bool)
|
||||
|
||||
ids_t = torch.tensor(full_ids, dtype=torch.long)
|
||||
attn_t = torch.tensor(attn, dtype=torch.bool)
|
||||
labels_t = torch.tensor(labels, dtype=torch.long)
|
||||
predict_actions = any(
|
||||
i < len(message_streams) and message_streams[i] == "low_level"
|
||||
for i in target_indices
|
||||
)
|
||||
if not _is_batched_messages(messages):
|
||||
ids_t = ids_t.squeeze(0)
|
||||
attn_t = attn_t.squeeze(0)
|
||||
labels_t = labels_t.squeeze(0)
|
||||
predict_actions_t = predict_actions_t.squeeze(0)
|
||||
|
||||
new_complementary = dict(comp)
|
||||
# Drop the per-recipe sidecar keys; everything downstream needs
|
||||
@@ -194,7 +176,7 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
||||
observation[OBS_LANGUAGE_TOKENS] = ids_t
|
||||
observation[OBS_LANGUAGE_ATTENTION_MASK] = attn_t
|
||||
new_complementary["text_labels"] = labels_t
|
||||
new_complementary["predict_actions"] = torch.tensor(predict_actions, dtype=torch.bool)
|
||||
new_complementary["predict_actions"] = predict_actions_t
|
||||
new_complementary.pop("task", None)
|
||||
|
||||
new_transition = dict(transition)
|
||||
@@ -202,6 +184,53 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
return new_transition
|
||||
|
||||
def _encode_messages(
    self,
    tokenizer: Any,
    messages: list[dict[str, Any]],
    message_streams: list[str | None],
    target_indices: list[int],
) -> tuple[list[int], list[int], bool]:
    """Tokenize one chat and build its supervision labels.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        messages: The chat messages of a single sample.
        message_streams: Stream name per message (may be shorter than
            ``messages``; out-of-range indices are treated as non-matching).
        target_indices: Indices of the messages that are supervised.

    Returns:
        ``(token_ids, labels, predict_actions)`` where ``labels`` is ``-100``
        everywhere except on the token span of each target message, and
        ``predict_actions`` is True when any target message belongs to the
        ``"low_level"`` stream.
    """
    text_messages = [_strip_lerobot_blocks(message) for message in messages]

    def render(upto: int | None) -> list[int]:
        # Single call site so every rendering uses the same tokenizer,
        # ``tools=`` argument and template settings.
        subset = text_messages if upto is None else text_messages[:upto]
        return _as_token_ids(
            tokenizer.apply_chat_template(
                subset,
                tools=self.tools,
                add_generation_prompt=False,
                tokenize=True,
                return_tensors=None,
            )
        )

    full_ids = render(None)
    total = len(full_ids)

    labels = [-100] * total
    for tgt in target_indices:
        # The target's tokens are those between the rendering that stops
        # just before it and the rendering that includes it.
        span_start = len(render(tgt))
        span_end = min(len(render(tgt + 1)), total)
        labels[span_start:span_end] = full_ids[span_start:span_end]

    predict_actions = False
    for idx in target_indices:
        if idx < len(message_streams) and message_streams[idx] == "low_level":
            predict_actions = True
            break

    return list(full_ids), labels, predict_actions
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
@@ -247,15 +276,11 @@ def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]:
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
text_parts.append({"type": "text", "text": str(block.get("text", ""))})
|
||||
# If only one text block survives, flatten to a string for
|
||||
# template friendliness; some chat templates choke on a single-
|
||||
# element list.
|
||||
if len(text_parts) == 1:
|
||||
new["content"] = text_parts[0]["text"]
|
||||
elif text_parts:
|
||||
new["content"] = text_parts
|
||||
else:
|
||||
new["content"] = ""
|
||||
new["content"] = text_parts or [{"type": "text", "text": ""}]
|
||||
elif content is None:
|
||||
new["content"] = [{"type": "text", "text": ""}]
|
||||
else:
|
||||
new["content"] = [{"type": "text", "text": str(content)}]
|
||||
if "tool_calls" in new and not new["tool_calls"]:
|
||||
# Drop empty tool_calls — some templates render them as a
|
||||
# spurious empty marker.
|
||||
@@ -267,5 +292,19 @@ def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]:
|
||||
return new
|
||||
|
||||
|
||||
def _is_batched_messages(messages: Any) -> bool:
|
||||
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
|
||||
|
||||
|
||||
def _as_token_ids(value: Any) -> list[int]:
|
||||
if isinstance(value, dict) or (hasattr(value, "keys") and "input_ids" in value.keys()):
|
||||
value = value["input_ids"]
|
||||
if hasattr(value, "tolist"):
|
||||
value = value.tolist()
|
||||
if isinstance(value, list) and value and isinstance(value[0], list):
|
||||
value = value[0]
|
||||
return [int(i) for i in value]
|
||||
|
||||
|
||||
# Re-export for tests / introspection
|
||||
strip_lerobot_blocks = _strip_lerobot_blocks
|
||||
|
||||
@@ -60,7 +60,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
config_class = SmolVLA2Config
|
||||
name = "smolvla2"
|
||||
|
||||
def __init__(self, config: SmolVLA2Config, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
def __init__(self, config: SmolVLA2Config, **kwargs):
|
||||
if not isinstance(config, SmolVLA2Config):
|
||||
config = SmolVLA2Config(
|
||||
**{
|
||||
@@ -69,7 +69,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
if hasattr(config, f.name)
|
||||
}
|
||||
)
|
||||
super().__init__(config, dataset_stats=dataset_stats)
|
||||
super().__init__(config, **kwargs)
|
||||
if config.unfreeze_lm_head and config.text_loss_weight > 0:
|
||||
self._unfreeze_lm_head()
|
||||
|
||||
@@ -200,7 +200,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=False,
|
||||
fill_kv_cache=False,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
|
||||
if prefix_out is None:
|
||||
@@ -228,8 +228,8 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
f"num_state={num_state})."
|
||||
)
|
||||
|
||||
lang_hidden = prefix_out[:, lang_start:lang_end]
|
||||
vlm = self.model.vlm_with_expert.vlm
|
||||
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
|
||||
logits = vlm.lm_head(lang_hidden) # (B, num_lang, vocab)
|
||||
|
||||
if text_labels.shape[1] != num_lang:
|
||||
@@ -244,12 +244,14 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
||||
# for the same convention.
|
||||
shift_logits = logits[:, :-1, :].contiguous()
|
||||
shift_labels = text_labels[:, 1:].contiguous().long()
|
||||
valid_labels = shift_labels != -100
|
||||
loss = F.cross_entropy(
|
||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||
shift_labels.reshape(-1),
|
||||
ignore_index=-100,
|
||||
reduction="sum",
|
||||
)
|
||||
return loss
|
||||
return loss / valid_labels.sum().clamp(min=1)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inference: text generation
|
||||
|
||||
@@ -175,9 +175,6 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
|
||||
complementary_data.pop("language_persistent", None)
|
||||
complementary_data.pop("language_events", None)
|
||||
|
||||
if "messages" in complementary_data:
|
||||
messages = complementary_data["messages"]
|
||||
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
|
||||
|
||||
@@ -51,6 +51,9 @@ class RenderMessagesStep(ProcessorStep):
|
||||
if not persistent and not events:
|
||||
return transition
|
||||
|
||||
if _is_batched_language(persistent) or _is_batched_language(events):
|
||||
return self._call_batch(transition, complementary_data, persistent, events)
|
||||
|
||||
timestamp = complementary_data.get("timestamp")
|
||||
if timestamp is None:
|
||||
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
|
||||
@@ -69,13 +72,64 @@ class RenderMessagesStep(ProcessorStep):
|
||||
return None
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
|
||||
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
|
||||
new_complementary_data.pop(LANGUAGE_EVENTS, None)
|
||||
new_complementary_data.update(rendered)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def _call_batch(
    self,
    transition: EnvTransition,
    complementary_data: dict[str, Any],
    persistent_batch: list,
    events_batch: list,
) -> EnvTransition | None:
    """Render every sample of a batched transition with the recipe.

    Each batch row is rendered through ``render_sample``; rows for which it
    returns ``None`` are dropped from the batch. Returns ``None`` when no row
    renders, otherwise a copy of the transition restricted to the kept rows,
    with the raw language columns removed and ``messages``,
    ``message_streams`` and ``target_message_indices`` stored as per-row
    lists in the complementary data.

    Raises:
        KeyError: if the complementary data carries no ``"timestamp"``.
    """
    timestamps = complementary_data.get("timestamp")
    if timestamps is None:
        raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")

    n_persistent = len(persistent_batch)
    n_events = len(events_batch)
    total = max(n_persistent, n_events)
    sample_indices = complementary_data.get("index", 0)
    tasks = complementary_data.get("task")

    kept_rows: list[int] = []
    batch_messages: list[list[dict[str, Any]]] = []
    batch_streams: list[list[str | None]] = []
    batch_targets: list[list[int]] = []

    for row in range(total):
        result = render_sample(
            recipe=self.recipe,
            # The shorter of the two columns is padded with empty rows.
            persistent=persistent_batch[row] if row < n_persistent else [],
            events=events_batch[row] if row < n_events else [],
            t=_batch_value(timestamps, row),
            sample_idx=int(_batch_value(sample_indices, row)),
            task=_batch_value(tasks, row),
            dataset_ctx=self.dataset_ctx,
        )
        if result is not None:
            kept_rows.append(row)
            batch_messages.append(result["messages"])
            batch_streams.append(result["message_streams"])
            batch_targets.append(result["target_message_indices"])

    if not batch_messages:
        return None

    if len(kept_rows) == total:
        new_transition = transition.copy()
    else:
        # Some rows were dropped: keep the remaining batch entries aligned
        # with the rendered messages.
        new_transition = _select_batch_indices(transition, kept_rows)

    updated = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
    for raw_column in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
        updated.pop(raw_column, None)
    updated["messages"] = batch_messages
    updated["message_streams"] = batch_streams
    updated["target_message_indices"] = batch_targets
    new_transition[TransitionKey.COMPLEMENTARY_DATA] = updated
    return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
@@ -90,3 +144,37 @@ def _scalar(value: Any) -> float | int:
|
||||
if isinstance(value, list) and len(value) == 1:
|
||||
return _scalar(value[0])
|
||||
return value
|
||||
|
||||
|
||||
def _is_batched_language(value: Any) -> bool:
|
||||
return isinstance(value, list) and bool(value) and isinstance(value[0], list)
|
||||
|
||||
|
||||
def _batch_value(value: Any, index: int) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, list):
|
||||
return value[index]
|
||||
if hasattr(value, "ndim") and getattr(value, "ndim") > 0:
|
||||
return _scalar(value[index])
|
||||
return _scalar(value)
|
||||
|
||||
|
||||
def _select_batch_indices(transition: EnvTransition, indices: list[int]) -> EnvTransition:
    """Return a copy of ``transition`` restricted to the given batch rows.

    Every value of the observation and complementary-data dicts, plus the
    action when present, is passed through ``_select_value``; other
    transition entries are left as they are in the copy.
    """
    subset = transition.copy()

    for section in (TransitionKey.OBSERVATION, TransitionKey.COMPLEMENTARY_DATA):
        payload = subset.get(section)
        if not isinstance(payload, dict):
            continue
        filtered = {}
        for name, item in payload.items():
            filtered[name] = _select_value(item, indices)
        subset[section] = filtered

    action = subset.get(TransitionKey.ACTION)
    if action is None:
        return subset
    subset[TransitionKey.ACTION] = _select_value(action, indices)
    return subset
|
||||
|
||||
|
||||
def _select_value(value: Any, indices: list[int]) -> Any:
|
||||
if isinstance(value, list) and len(value) >= len(indices):
|
||||
return [value[i] for i in indices]
|
||||
if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0:
|
||||
return value.index_select(0, value.new_tensor(indices).long())
|
||||
return value
|
||||
|
||||
@@ -22,7 +22,7 @@ from torch.utils.data._utils.collate import default_collate
|
||||
|
||||
from lerobot.datasets.language import LANGUAGE_COLUMNS
|
||||
|
||||
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
|
||||
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices", *LANGUAGE_COLUMNS}
|
||||
|
||||
|
||||
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
|
||||
|
||||
Reference in New Issue
Block a user