From a1b8134ef168ec360c0528c5116d8819df134dce Mon Sep 17 00:00:00 2001 From: pepijn Date: Tue, 5 May 2026 08:55:56 +0000 Subject: [PATCH] fix(smolvla2): train on rendered language batches Keep annotated language columns through collation, render batched recipe samples, and make SmolVLA2 text loss robust enough for distributed training on the steerable dataset. Co-authored-by: Cursor --- .../configs/recipes/smolvla2_hirobot.yaml | 6 +- src/lerobot/datasets/__init__.py | 14 +- src/lerobot/datasets/language.py | 18 +- src/lerobot/policies/__init__.py | 2 + .../smolvla2/chat_processor_smolvla2.py | 205 +++++++++++------- .../policies/smolvla2/modeling_smolvla2.py | 12 +- src/lerobot/processor/batch_processor.py | 3 - .../processor/render_messages_processor.py | 90 +++++++- src/lerobot/utils/collate.py | 2 +- 9 files changed, 253 insertions(+), 99 deletions(-) diff --git a/src/lerobot/configs/recipes/smolvla2_hirobot.yaml b/src/lerobot/configs/recipes/smolvla2_hirobot.yaml index a13cb8bf0..2586d9529 100644 --- a/src/lerobot/configs/recipes/smolvla2_hirobot.yaml +++ b/src/lerobot/configs/recipes/smolvla2_hirobot.yaml @@ -62,14 +62,14 @@ blend: ask_vqa_top: weight: 0.10 bindings: - vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)" - vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)" + vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)" + vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)" messages: - role: user stream: high_level if_present: vqa_query content: - - {type: image, feature: observation.images.top} + - {type: image, feature: observation.images.front} - {type: text, text: "${vqa_query}"} - {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa} diff --git a/src/lerobot/datasets/__init__.py b/src/lerobot/datasets/__init__.py index 8be3609f3..067f91091 100644 --- a/src/lerobot/datasets/__init__.py +++ b/src/lerobot/datasets/__init__.py @@ -34,7 +34,6 @@ from .dataset_tools import ( remove_feature, split_dataset, ) -from .factory import make_dataset, resolve_delta_timestamps from .image_writer import safe_stop_image_writer from .io_utils import load_episodes, write_stats from .language import ( @@ -53,6 +52,19 @@ from .streaming_dataset import StreamingLeRobotDataset from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card from .video_utils import VideoEncodingManager + +def make_dataset(*args, **kwargs): + from .factory import make_dataset as _make_dataset + + return _make_dataset(*args, **kwargs) + + +def resolve_delta_timestamps(*args, **kwargs): + from .factory import resolve_delta_timestamps as _resolve_delta_timestamps + + return _resolve_delta_timestamps(*args, **kwargs) + + # NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.) # and legacy migration constants are intentionally NOT re-exported here. # Import directly: ``from lerobot.datasets.io_utils import ...`` diff --git a/src/lerobot/datasets/language.py b/src/lerobot/datasets/language.py index 8ab4e006e..61749bc70 100644 --- a/src/lerobot/datasets/language.py +++ b/src/lerobot/datasets/language.py @@ -64,8 +64,22 @@ def _json_arrow_type() -> pa.DataType: def _json_feature() -> object: - """Return the HF ``datasets`` JSON feature, falling back to a string value.""" - return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string") + """Return the HF feature used for tool-call payloads. + + Older ``datasets`` versions do not expose ``datasets.Json``. The + annotation pipeline currently emits the canonical ``say`` tool call + shape, so use that explicit struct instead of falling back to a string + that cannot cast structured parquet values. + """ + if hasattr(datasets, "Json"): + return datasets.Json() + return { + "type": datasets.Value("string"), + "function": { + "name": datasets.Value("string"), + "arguments": {"text": datasets.Value("string")}, + }, + } def language_persistent_row_arrow_type() -> pa.StructType: diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index e138a84d9..c704f5dc5 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -26,6 +26,7 @@ from .sac.configuration_sac import SACConfig as SACConfig from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig from .sarm.configuration_sarm import SARMConfig as SARMConfig from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig +from .smolvla2.configuration_smolvla2 import SmolVLA2Config as SmolVLA2Config from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .utils import make_robot_action, prepare_observation_for_inference from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig @@ -49,6 +50,7 @@ __all__ = [ "SACConfig", "SARMConfig", "SmolVLAConfig", + "SmolVLA2Config", "TDMPCConfig", "VQBeTConfig", "WallXConfig", diff --git a/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py b/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py index f2b771b64..bebcdd04f 100644 --- a/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py +++ b/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py @@ -99,85 +99,67 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep): # falls back to whatever ``task`` is in the transition. return transition - message_streams: list[str | None] = list(comp.get("message_streams") or []) - target_indices: list[int] = sorted( - int(i) for i in (comp.get("target_message_indices") or []) - ) - tokenizer = self._get_tokenizer() - text_messages = [_strip_lerobot_blocks(m) for m in messages] - # Tokenize the full chat once. - full_ids = tokenizer.apply_chat_template( - text_messages, - tools=self.tools, - add_generation_prompt=False, - tokenize=True, - return_tensors=None, + if _is_batched_messages(messages): + encoded = [ + self._encode_messages( + tokenizer, + msg, + list(streams), + sorted(int(i) for i in indices), + ) + for msg, streams, indices in zip( + messages, + comp.get("message_streams") or [[] for _ in messages], + comp.get("target_message_indices") or [[] for _ in messages], + strict=True, + ) + ] + else: + encoded = [ + self._encode_messages( + tokenizer, + messages, + list(comp.get("message_streams") or []), + sorted(int(i) for i in (comp.get("target_message_indices") or [])), + ) + ] + + pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 + target_length = self.max_length if self.padding == "max_length" else max( + len(ids) for ids, _, _ in encoded ) - if isinstance(full_ids, list) and full_ids and isinstance(full_ids[0], list): - full_ids = full_ids[0] + target_length = min(target_length, self.max_length) - # Build the label mask by re-rendering progressively up to each - # target message and reading off the prefix length. This is the - # robust way to get exact token boundaries: we use the same - # tokenizer, the same ``tools=`` argument, and the same chat - # template — so the prefix tokens are guaranteed to be a prefix - # of the full sequence. - labels = [-100] * len(full_ids) - for tgt in target_indices: - prefix_ids = tokenizer.apply_chat_template( - text_messages[:tgt], - tools=self.tools, - add_generation_prompt=False, - tokenize=True, - return_tensors=None, - ) - full_through_target = tokenizer.apply_chat_template( - text_messages[: tgt + 1], - tools=self.tools, - add_generation_prompt=False, - tokenize=True, - return_tensors=None, - ) - if isinstance(prefix_ids, list) and prefix_ids and isinstance(prefix_ids[0], list): - prefix_ids = prefix_ids[0] - if ( - isinstance(full_through_target, list) - and full_through_target - and isinstance(full_through_target[0], list) - ): - full_through_target = full_through_target[0] - start = len(prefix_ids) - end = min(len(full_through_target), len(full_ids)) - for pos in range(start, end): - labels[pos] = int(full_ids[pos]) + ids_batch = [] + attn_batch = [] + labels_batch = [] + predict_actions = [] + for ids, labels, predict_action in encoded: + ids = ids[:target_length] + labels = labels[:target_length] + attn = [1] * len(ids) + if len(ids) < target_length: + n_pad = target_length - len(ids) + ids = ids + [pad_id] * n_pad + labels = labels + [-100] * n_pad + attn = attn + [0] * n_pad + ids_batch.append(ids) + attn_batch.append(attn) + labels_batch.append(labels) + predict_actions.append(predict_action) - # Truncate / pad to ``max_length`` so batches collate cleanly. - # The SmolVLA pipeline downstream relies on a fixed length - # behaviour ("longest" or "max_length") — we mirror it here. - if len(full_ids) > self.max_length: - full_ids = full_ids[: self.max_length] - labels = labels[: self.max_length] - attn = [1] * len(full_ids) - if self.padding == "max_length" and len(full_ids) < self.max_length: - pad_id = ( - tokenizer.pad_token_id - if tokenizer.pad_token_id is not None - else 0 - ) - n_pad = self.max_length - len(full_ids) - full_ids = full_ids + [pad_id] * n_pad - labels = labels + [-100] * n_pad - attn = attn + [0] * n_pad + ids_t = torch.tensor(ids_batch, dtype=torch.long) + attn_t = torch.tensor(attn_batch, dtype=torch.bool) + labels_t = torch.tensor(labels_batch, dtype=torch.long) + predict_actions_t = torch.tensor(predict_actions, dtype=torch.bool) - ids_t = torch.tensor(full_ids, dtype=torch.long) - attn_t = torch.tensor(attn, dtype=torch.bool) - labels_t = torch.tensor(labels, dtype=torch.long) - predict_actions = any( - i < len(message_streams) and message_streams[i] == "low_level" - for i in target_indices - ) + if not _is_batched_messages(messages): + ids_t = ids_t.squeeze(0) + attn_t = attn_t.squeeze(0) + labels_t = labels_t.squeeze(0) + predict_actions_t = predict_actions_t.squeeze(0) new_complementary = dict(comp) # Drop the per-recipe sidecar keys; everything downstream needs @@ -194,7 +176,7 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep): observation[OBS_LANGUAGE_TOKENS] = ids_t observation[OBS_LANGUAGE_ATTENTION_MASK] = attn_t new_complementary["text_labels"] = labels_t - new_complementary["predict_actions"] = torch.tensor(predict_actions, dtype=torch.bool) + new_complementary["predict_actions"] = predict_actions_t new_complementary.pop("task", None) new_transition = dict(transition) @@ -202,6 +184,53 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep): new_transition[TransitionKey.OBSERVATION] = observation return new_transition + def _encode_messages( + self, + tokenizer: Any, + messages: list[dict[str, Any]], + message_streams: list[str | None], + target_indices: list[int], + ) -> tuple[list[int], list[int], bool]: + text_messages = [_strip_lerobot_blocks(m) for m in messages] + + full_ids = tokenizer.apply_chat_template( + text_messages, + tools=self.tools, + add_generation_prompt=False, + tokenize=True, + return_tensors=None, + ) + full_ids = _as_token_ids(full_ids) + + labels = [-100] * len(full_ids) + for tgt in target_indices: + prefix_ids = tokenizer.apply_chat_template( + text_messages[:tgt], + tools=self.tools, + add_generation_prompt=False, + tokenize=True, + return_tensors=None, + ) + full_through_target = tokenizer.apply_chat_template( + text_messages[: tgt + 1], + tools=self.tools, + add_generation_prompt=False, + tokenize=True, + return_tensors=None, + ) + prefix_ids = _as_token_ids(prefix_ids) + full_through_target = _as_token_ids(full_through_target) + start = len(prefix_ids) + end = min(len(full_through_target), len(full_ids)) + for pos in range(start, end): + labels[pos] = int(full_ids[pos]) + + predict_actions = any( + i < len(message_streams) and message_streams[i] == "low_level" + for i in target_indices + ) + return [int(i) for i in full_ids], labels, predict_actions + def transform_features( self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: @@ -247,15 +276,11 @@ def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]: continue if block.get("type") == "text": text_parts.append({"type": "text", "text": str(block.get("text", ""))}) - # If only one text block survives, flatten to a string for - # template friendliness; some chat templates choke on a single- - # element list. - if len(text_parts) == 1: - new["content"] = text_parts[0]["text"] - elif text_parts: - new["content"] = text_parts - else: - new["content"] = "" + new["content"] = text_parts or [{"type": "text", "text": ""}] + elif content is None: + new["content"] = [{"type": "text", "text": ""}] + else: + new["content"] = [{"type": "text", "text": str(content)}] if "tool_calls" in new and not new["tool_calls"]: # Drop empty tool_calls — some templates render them as a # spurious empty marker. @@ -267,5 +292,19 @@ def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]: return new +def _is_batched_messages(messages: Any) -> bool: + return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list) + + +def _as_token_ids(value: Any) -> list[int]: + if isinstance(value, dict) or (hasattr(value, "keys") and "input_ids" in value.keys()): + value = value["input_ids"] + if hasattr(value, "tolist"): + value = value.tolist() + if isinstance(value, list) and value and isinstance(value[0], list): + value = value[0] + return [int(i) for i in value] + + # Re-export for tests / introspection strip_lerobot_blocks = _strip_lerobot_blocks diff --git a/src/lerobot/policies/smolvla2/modeling_smolvla2.py b/src/lerobot/policies/smolvla2/modeling_smolvla2.py index 258db6f68..9a54f35b0 100644 --- a/src/lerobot/policies/smolvla2/modeling_smolvla2.py +++ b/src/lerobot/policies/smolvla2/modeling_smolvla2.py @@ -60,7 +60,7 @@ class SmolVLA2Policy(SmolVLAPolicy): config_class = SmolVLA2Config name = "smolvla2" - def __init__(self, config: SmolVLA2Config, dataset_stats: dict[str, dict[str, Tensor]] | None = None): + def __init__(self, config: SmolVLA2Config, **kwargs): if not isinstance(config, SmolVLA2Config): config = SmolVLA2Config( **{ @@ -69,7 +69,7 @@ class SmolVLA2Policy(SmolVLAPolicy): if hasattr(config, f.name) } ) - super().__init__(config, dataset_stats=dataset_stats) + super().__init__(config, **kwargs) if config.unfreeze_lm_head and config.text_loss_weight > 0: self._unfreeze_lm_head() @@ -200,7 +200,7 @@ class SmolVLA2Policy(SmolVLAPolicy): past_key_values=None, inputs_embeds=[prefix_embs, None], use_cache=False, - fill_kv_cache=False, + fill_kv_cache=True, ) prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair if prefix_out is None: @@ -228,8 +228,8 @@ class SmolVLA2Policy(SmolVLAPolicy): f"num_state={num_state})." ) - lang_hidden = prefix_out[:, lang_start:lang_end] vlm = self.model.vlm_with_expert.vlm + lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype) logits = vlm.lm_head(lang_hidden) # (B, num_lang, vocab) if text_labels.shape[1] != num_lang: @@ -244,12 +244,14 @@ class SmolVLA2Policy(SmolVLAPolicy): # for the same convention. shift_logits = logits[:, :-1, :].contiguous() shift_labels = text_labels[:, 1:].contiguous().long() + valid_labels = shift_labels != -100 loss = F.cross_entropy( shift_logits.reshape(-1, shift_logits.shape[-1]), shift_labels.reshape(-1), ignore_index=-100, + reduction="sum", ) - return loss + return loss / valid_labels.sum().clamp(min=1) # ------------------------------------------------------------------ # Inference: text generation diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index 669c68a0a..804a3aaf0 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -175,9 +175,6 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep): if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0: complementary_data["task_index"] = task_index_value.unsqueeze(0) - complementary_data.pop("language_persistent", None) - complementary_data.pop("language_events", None) - if "messages" in complementary_data: messages = complementary_data["messages"] if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)): diff --git a/src/lerobot/processor/render_messages_processor.py b/src/lerobot/processor/render_messages_processor.py index 7d88fab73..4c9a25c4c 100644 --- a/src/lerobot/processor/render_messages_processor.py +++ b/src/lerobot/processor/render_messages_processor.py @@ -51,6 +51,9 @@ class RenderMessagesStep(ProcessorStep): if not persistent and not events: return transition + if _is_batched_language(persistent) or _is_batched_language(events): + return self._call_batch(transition, complementary_data, persistent, events) + timestamp = complementary_data.get("timestamp") if timestamp is None: raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.") @@ -69,13 +72,64 @@ class RenderMessagesStep(ProcessorStep): return None new_transition = transition.copy() - new_complementary_data = dict(complementary_data) + new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}) new_complementary_data.pop(LANGUAGE_PERSISTENT, None) new_complementary_data.pop(LANGUAGE_EVENTS, None) new_complementary_data.update(rendered) new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data return new_transition + def _call_batch( + self, + transition: EnvTransition, + complementary_data: dict[str, Any], + persistent_batch: list, + events_batch: list, + ) -> EnvTransition | None: + timestamp = complementary_data.get("timestamp") + if timestamp is None: + raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.") + + batch_size = max(len(persistent_batch), len(events_batch)) + messages: list[list[dict[str, Any]]] = [] + message_streams: list[list[str | None]] = [] + target_message_indices: list[list[int]] = [] + keep_indices: list[int] = [] + + for i in range(batch_size): + rendered = render_sample( + recipe=self.recipe, + persistent=persistent_batch[i] if i < len(persistent_batch) else [], + events=events_batch[i] if i < len(events_batch) else [], + t=_batch_value(timestamp, i), + sample_idx=int(_batch_value(complementary_data.get("index", 0), i)), + task=_batch_value(complementary_data.get("task"), i), + dataset_ctx=self.dataset_ctx, + ) + if rendered is None: + continue + keep_indices.append(i) + messages.append(rendered["messages"]) + message_streams.append(rendered["message_streams"]) + target_message_indices.append(rendered["target_message_indices"]) + + if not messages: + return None + + new_transition = ( + _select_batch_indices(transition, keep_indices) + if len(keep_indices) != batch_size + else transition.copy() + ) + new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}) + new_complementary_data.pop(LANGUAGE_PERSISTENT, None) + new_complementary_data.pop(LANGUAGE_EVENTS, None) + new_complementary_data["messages"] = messages + new_complementary_data["message_streams"] = message_streams + new_complementary_data["target_message_indices"] = target_message_indices + new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data + return new_transition + def transform_features( self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: @@ -90,3 +144,37 @@ def _scalar(value: Any) -> float | int: if isinstance(value, list) and len(value) == 1: return _scalar(value[0]) return value + + +def _is_batched_language(value: Any) -> bool: + return isinstance(value, list) and bool(value) and isinstance(value[0], list) + + +def _batch_value(value: Any, index: int) -> Any: + if value is None: + return None + if isinstance(value, list): + return value[index] + if hasattr(value, "ndim") and getattr(value, "ndim") > 0: + return _scalar(value[index]) + return _scalar(value) + + +def _select_batch_indices(transition: EnvTransition, indices: list[int]) -> EnvTransition: + selected = transition.copy() + for key in (TransitionKey.OBSERVATION, TransitionKey.COMPLEMENTARY_DATA): + data = selected.get(key) + if isinstance(data, dict): + selected[key] = {k: _select_value(v, indices) for k, v in data.items()} + action = selected.get(TransitionKey.ACTION) + if action is not None: + selected[TransitionKey.ACTION] = _select_value(action, indices) + return selected + + +def _select_value(value: Any, indices: list[int]) -> Any: + if isinstance(value, list) and len(value) >= len(indices): + return [value[i] for i in indices] + if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0: + return value.index_select(0, value.new_tensor(indices).long()) + return value diff --git a/src/lerobot/utils/collate.py b/src/lerobot/utils/collate.py index ca32430cd..1bfca32f3 100644 --- a/src/lerobot/utils/collate.py +++ b/src/lerobot/utils/collate.py @@ -22,7 +22,7 @@ from torch.utils.data._utils.collate import default_collate from lerobot.datasets.language import LANGUAGE_COLUMNS -_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"} +_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices", *LANGUAGE_COLUMNS} def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None: