fix(smolvla2): train on rendered language batches

Keep annotated language columns through collation, render batched recipe samples, and make SmolVLA2 text loss robust enough for distributed training on the steerable dataset.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-05 08:55:56 +00:00
parent 5f7c6ba61d
commit a1b8134ef1
9 changed files with 253 additions and 99 deletions
@@ -62,14 +62,14 @@ blend:
ask_vqa_top:
weight: 0.10
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.top}
- {type: image, feature: observation.images.front}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
+13 -1
View File
@@ -34,7 +34,6 @@ from .dataset_tools import (
remove_feature,
split_dataset,
)
from .factory import make_dataset, resolve_delta_timestamps
from .image_writer import safe_stop_image_writer
from .io_utils import load_episodes, write_stats
from .language import (
@@ -53,6 +52,19 @@ from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
from .video_utils import VideoEncodingManager
def make_dataset(*args, **kwargs):
from .factory import make_dataset as _make_dataset
return _make_dataset(*args, **kwargs)
def resolve_delta_timestamps(*args, **kwargs):
from .factory import resolve_delta_timestamps as _resolve_delta_timestamps
return _resolve_delta_timestamps(*args, **kwargs)
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
# and legacy migration constants are intentionally NOT re-exported here.
# Import directly: ``from lerobot.datasets.io_utils import ...``
+16 -2
View File
@@ -64,8 +64,22 @@ def _json_arrow_type() -> pa.DataType:
def _json_feature() -> object:
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
"""Return the HF feature used for tool-call payloads.
Older ``datasets`` versions do not expose ``datasets.Json``. The
annotation pipeline currently emits the canonical ``say`` tool call
shape, so use that explicit struct instead of falling back to a string
that cannot cast structured parquet values.
"""
if hasattr(datasets, "Json"):
return datasets.Json()
return {
"type": datasets.Value("string"),
"function": {
"name": datasets.Value("string"),
"arguments": {"text": datasets.Value("string")},
},
}
def language_persistent_row_arrow_type() -> pa.StructType:
+2
View File
@@ -26,6 +26,7 @@ from .sac.configuration_sac import SACConfig as SACConfig
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
from .sarm.configuration_sarm import SARMConfig as SARMConfig
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla2.configuration_smolvla2 import SmolVLA2Config as SmolVLA2Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .utils import make_robot_action, prepare_observation_for_inference
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
@@ -49,6 +50,7 @@ __all__ = [
"SACConfig",
"SARMConfig",
"SmolVLAConfig",
"SmolVLA2Config",
"TDMPCConfig",
"VQBeTConfig",
"WallXConfig",
@@ -99,85 +99,67 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
# falls back to whatever ``task`` is in the transition.
return transition
message_streams: list[str | None] = list(comp.get("message_streams") or [])
target_indices: list[int] = sorted(
int(i) for i in (comp.get("target_message_indices") or [])
)
tokenizer = self._get_tokenizer()
text_messages = [_strip_lerobot_blocks(m) for m in messages]
# Tokenize the full chat once.
full_ids = tokenizer.apply_chat_template(
text_messages,
tools=self.tools,
add_generation_prompt=False,
tokenize=True,
return_tensors=None,
if _is_batched_messages(messages):
encoded = [
self._encode_messages(
tokenizer,
msg,
list(streams),
sorted(int(i) for i in indices),
)
for msg, streams, indices in zip(
messages,
comp.get("message_streams") or [[] for _ in messages],
comp.get("target_message_indices") or [[] for _ in messages],
strict=True,
)
]
else:
encoded = [
self._encode_messages(
tokenizer,
messages,
list(comp.get("message_streams") or []),
sorted(int(i) for i in (comp.get("target_message_indices") or [])),
)
]
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
target_length = self.max_length if self.padding == "max_length" else max(
len(ids) for ids, _, _ in encoded
)
if isinstance(full_ids, list) and full_ids and isinstance(full_ids[0], list):
full_ids = full_ids[0]
target_length = min(target_length, self.max_length)
# Build the label mask by re-rendering progressively up to each
# target message and reading off the prefix length. This is the
# robust way to get exact token boundaries: we use the same
# tokenizer, the same ``tools=`` argument, and the same chat
# template — so the prefix tokens are guaranteed to be a prefix
# of the full sequence.
labels = [-100] * len(full_ids)
for tgt in target_indices:
prefix_ids = tokenizer.apply_chat_template(
text_messages[:tgt],
tools=self.tools,
add_generation_prompt=False,
tokenize=True,
return_tensors=None,
)
full_through_target = tokenizer.apply_chat_template(
text_messages[: tgt + 1],
tools=self.tools,
add_generation_prompt=False,
tokenize=True,
return_tensors=None,
)
if isinstance(prefix_ids, list) and prefix_ids and isinstance(prefix_ids[0], list):
prefix_ids = prefix_ids[0]
if (
isinstance(full_through_target, list)
and full_through_target
and isinstance(full_through_target[0], list)
):
full_through_target = full_through_target[0]
start = len(prefix_ids)
end = min(len(full_through_target), len(full_ids))
for pos in range(start, end):
labels[pos] = int(full_ids[pos])
ids_batch = []
attn_batch = []
labels_batch = []
predict_actions = []
for ids, labels, predict_action in encoded:
ids = ids[:target_length]
labels = labels[:target_length]
attn = [1] * len(ids)
if len(ids) < target_length:
n_pad = target_length - len(ids)
ids = ids + [pad_id] * n_pad
labels = labels + [-100] * n_pad
attn = attn + [0] * n_pad
ids_batch.append(ids)
attn_batch.append(attn)
labels_batch.append(labels)
predict_actions.append(predict_action)
# Truncate / pad to ``max_length`` so batches collate cleanly.
# The SmolVLA pipeline downstream relies on a fixed length
# behaviour ("longest" or "max_length") — we mirror it here.
if len(full_ids) > self.max_length:
full_ids = full_ids[: self.max_length]
labels = labels[: self.max_length]
attn = [1] * len(full_ids)
if self.padding == "max_length" and len(full_ids) < self.max_length:
pad_id = (
tokenizer.pad_token_id
if tokenizer.pad_token_id is not None
else 0
)
n_pad = self.max_length - len(full_ids)
full_ids = full_ids + [pad_id] * n_pad
labels = labels + [-100] * n_pad
attn = attn + [0] * n_pad
ids_t = torch.tensor(ids_batch, dtype=torch.long)
attn_t = torch.tensor(attn_batch, dtype=torch.bool)
labels_t = torch.tensor(labels_batch, dtype=torch.long)
predict_actions_t = torch.tensor(predict_actions, dtype=torch.bool)
ids_t = torch.tensor(full_ids, dtype=torch.long)
attn_t = torch.tensor(attn, dtype=torch.bool)
labels_t = torch.tensor(labels, dtype=torch.long)
predict_actions = any(
i < len(message_streams) and message_streams[i] == "low_level"
for i in target_indices
)
if not _is_batched_messages(messages):
ids_t = ids_t.squeeze(0)
attn_t = attn_t.squeeze(0)
labels_t = labels_t.squeeze(0)
predict_actions_t = predict_actions_t.squeeze(0)
new_complementary = dict(comp)
# Drop the per-recipe sidecar keys; everything downstream needs
@@ -194,7 +176,7 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
observation[OBS_LANGUAGE_TOKENS] = ids_t
observation[OBS_LANGUAGE_ATTENTION_MASK] = attn_t
new_complementary["text_labels"] = labels_t
new_complementary["predict_actions"] = torch.tensor(predict_actions, dtype=torch.bool)
new_complementary["predict_actions"] = predict_actions_t
new_complementary.pop("task", None)
new_transition = dict(transition)
@@ -202,6 +184,53 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
new_transition[TransitionKey.OBSERVATION] = observation
return new_transition
def _encode_messages(
self,
tokenizer: Any,
messages: list[dict[str, Any]],
message_streams: list[str | None],
target_indices: list[int],
) -> tuple[list[int], list[int], bool]:
text_messages = [_strip_lerobot_blocks(m) for m in messages]
full_ids = tokenizer.apply_chat_template(
text_messages,
tools=self.tools,
add_generation_prompt=False,
tokenize=True,
return_tensors=None,
)
full_ids = _as_token_ids(full_ids)
labels = [-100] * len(full_ids)
for tgt in target_indices:
prefix_ids = tokenizer.apply_chat_template(
text_messages[:tgt],
tools=self.tools,
add_generation_prompt=False,
tokenize=True,
return_tensors=None,
)
full_through_target = tokenizer.apply_chat_template(
text_messages[: tgt + 1],
tools=self.tools,
add_generation_prompt=False,
tokenize=True,
return_tensors=None,
)
prefix_ids = _as_token_ids(prefix_ids)
full_through_target = _as_token_ids(full_through_target)
start = len(prefix_ids)
end = min(len(full_through_target), len(full_ids))
for pos in range(start, end):
labels[pos] = int(full_ids[pos])
predict_actions = any(
i < len(message_streams) and message_streams[i] == "low_level"
for i in target_indices
)
return [int(i) for i in full_ids], labels, predict_actions
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
@@ -247,15 +276,11 @@ def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]:
continue
if block.get("type") == "text":
text_parts.append({"type": "text", "text": str(block.get("text", ""))})
# If only one text block survives, flatten to a string for
# template friendliness; some chat templates choke on a single-
# element list.
if len(text_parts) == 1:
new["content"] = text_parts[0]["text"]
elif text_parts:
new["content"] = text_parts
else:
new["content"] = ""
new["content"] = text_parts or [{"type": "text", "text": ""}]
elif content is None:
new["content"] = [{"type": "text", "text": ""}]
else:
new["content"] = [{"type": "text", "text": str(content)}]
if "tool_calls" in new and not new["tool_calls"]:
# Drop empty tool_calls — some templates render them as a
# spurious empty marker.
@@ -267,5 +292,19 @@ def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]:
return new
def _is_batched_messages(messages: Any) -> bool:
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
def _as_token_ids(value: Any) -> list[int]:
if isinstance(value, dict) or (hasattr(value, "keys") and "input_ids" in value.keys()):
value = value["input_ids"]
if hasattr(value, "tolist"):
value = value.tolist()
if isinstance(value, list) and value and isinstance(value[0], list):
value = value[0]
return [int(i) for i in value]
# Re-export for tests / introspection
strip_lerobot_blocks = _strip_lerobot_blocks
@@ -60,7 +60,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
config_class = SmolVLA2Config
name = "smolvla2"
def __init__(self, config: SmolVLA2Config, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
def __init__(self, config: SmolVLA2Config, **kwargs):
if not isinstance(config, SmolVLA2Config):
config = SmolVLA2Config(
**{
@@ -69,7 +69,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
if hasattr(config, f.name)
}
)
super().__init__(config, dataset_stats=dataset_stats)
super().__init__(config, **kwargs)
if config.unfreeze_lm_head and config.text_loss_weight > 0:
self._unfreeze_lm_head()
@@ -200,7 +200,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=False,
fill_kv_cache=False,
fill_kv_cache=True,
)
prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
if prefix_out is None:
@@ -228,8 +228,8 @@ class SmolVLA2Policy(SmolVLAPolicy):
f"num_state={num_state})."
)
lang_hidden = prefix_out[:, lang_start:lang_end]
vlm = self.model.vlm_with_expert.vlm
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
logits = vlm.lm_head(lang_hidden) # (B, num_lang, vocab)
if text_labels.shape[1] != num_lang:
@@ -244,12 +244,14 @@ class SmolVLA2Policy(SmolVLAPolicy):
# for the same convention.
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = text_labels[:, 1:].contiguous().long()
valid_labels = shift_labels != -100
loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_labels.reshape(-1),
ignore_index=-100,
reduction="sum",
)
return loss
return loss / valid_labels.sum().clamp(min=1)
# ------------------------------------------------------------------
# Inference: text generation
-3
View File
@@ -175,9 +175,6 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
complementary_data.pop("language_persistent", None)
complementary_data.pop("language_events", None)
if "messages" in complementary_data:
messages = complementary_data["messages"]
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
@@ -51,6 +51,9 @@ class RenderMessagesStep(ProcessorStep):
if not persistent and not events:
return transition
if _is_batched_language(persistent) or _is_batched_language(events):
return self._call_batch(transition, complementary_data, persistent, events)
timestamp = complementary_data.get("timestamp")
if timestamp is None:
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
@@ -69,13 +72,64 @@ class RenderMessagesStep(ProcessorStep):
return None
new_transition = transition.copy()
new_complementary_data = dict(complementary_data)
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data.update(rendered)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def _call_batch(
self,
transition: EnvTransition,
complementary_data: dict[str, Any],
persistent_batch: list,
events_batch: list,
) -> EnvTransition | None:
timestamp = complementary_data.get("timestamp")
if timestamp is None:
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
batch_size = max(len(persistent_batch), len(events_batch))
messages: list[list[dict[str, Any]]] = []
message_streams: list[list[str | None]] = []
target_message_indices: list[list[int]] = []
keep_indices: list[int] = []
for i in range(batch_size):
rendered = render_sample(
recipe=self.recipe,
persistent=persistent_batch[i] if i < len(persistent_batch) else [],
events=events_batch[i] if i < len(events_batch) else [],
t=_batch_value(timestamp, i),
sample_idx=int(_batch_value(complementary_data.get("index", 0), i)),
task=_batch_value(complementary_data.get("task"), i),
dataset_ctx=self.dataset_ctx,
)
if rendered is None:
continue
keep_indices.append(i)
messages.append(rendered["messages"])
message_streams.append(rendered["message_streams"])
target_message_indices.append(rendered["target_message_indices"])
if not messages:
return None
new_transition = (
_select_batch_indices(transition, keep_indices)
if len(keep_indices) != batch_size
else transition.copy()
)
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data["messages"] = messages
new_complementary_data["message_streams"] = message_streams
new_complementary_data["target_message_indices"] = target_message_indices
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
@@ -90,3 +144,37 @@ def _scalar(value: Any) -> float | int:
if isinstance(value, list) and len(value) == 1:
return _scalar(value[0])
return value
def _is_batched_language(value: Any) -> bool:
return isinstance(value, list) and bool(value) and isinstance(value[0], list)
def _batch_value(value: Any, index: int) -> Any:
if value is None:
return None
if isinstance(value, list):
return value[index]
if hasattr(value, "ndim") and getattr(value, "ndim") > 0:
return _scalar(value[index])
return _scalar(value)
def _select_batch_indices(transition: EnvTransition, indices: list[int]) -> EnvTransition:
selected = transition.copy()
for key in (TransitionKey.OBSERVATION, TransitionKey.COMPLEMENTARY_DATA):
data = selected.get(key)
if isinstance(data, dict):
selected[key] = {k: _select_value(v, indices) for k, v in data.items()}
action = selected.get(TransitionKey.ACTION)
if action is not None:
selected[TransitionKey.ACTION] = _select_value(action, indices)
return selected
def _select_value(value: Any, indices: list[int]) -> Any:
if isinstance(value, list) and len(value) >= len(indices):
return [value[i] for i in indices]
if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0:
return value.index_select(0, value.new_tensor(indices).long())
return value
+1 -1
View File
@@ -22,7 +22,7 @@ from torch.utils.data._utils.collate import default_collate
from lerobot.datasets.language import LANGUAGE_COLUMNS
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices", *LANGUAGE_COLUMNS}
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None: