Compare commits

..

341 Commits

Author SHA1 Message Date
pepijn223 3427499212 feat(pi052): condition low-level prompt on state + fix eval slowdown
- Inject discretized proprioceptive state (256 bins, pi05 format) into
  low-level action-conditioning prompts in both training
  (PI052TextTokenizerStep) and eval (_with_low_level_subtask_prompt),
  matching the recipe's documented "[images, subtask, state]" intent.
  Higher-level subtask/memory text streams stay state-free.
- Cache the loc-token tokenizer (_get_loc_tokenizer) instead of reloading
  it from disk on every _build_text_batch/select_message call (it ran
  twice per env per replan and dominated eval runtime).
- Add a KV cache to select_message decode (bit-identical output to the
  recompute path) to avoid O(n^2) generation.

Net: pi052 eval ~2.9 s/it -> ~0.1 s/it (~25x).
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-14 13:57:55 +02:00
Pepijn c5965d4971 Merge branch 'main' into feat/smolvla-on-steerable 2026-06-08 11:02:54 +02:00
pepijn223 470fdd195d fix(ema): default EMA decay to 0.99
Matches openpi's top-level default (ema_decay=0.99, ~last 100 steps).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-05 16:10:00 +02:00
pepijn223 384feca91a fix(ema): default EMAConfig.enable to False (opt-in)
EMA was on by default, so every training run on the branch (incl. VLA-JEPA
and other non-flow-matching policies) created a full fp32 shadow copy. EMA
only benefits flow-matching/diffusion policies (pi0/pi05/pi052). Make it
opt-in via --ema.enable=true; the pi05/pi052 recipes already pass that flag.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-05 16:09:08 +02:00
pepijn223 7b35af6eca Merge remote-tracking branch 'origin/main' into feat/smolvla-on-steerable
Co-authored-by: Cursor <cursoragent@cursor.com>

# Conflicts:
#	uv.lock
2026-06-05 14:38:47 +02:00
pepijn223 aca02ff24c fix(robocasa): align env state/action order to openpi/robocasa convention
LeRobot's RoboCasaEnv used a divergent flat state/action layout vs the
robocasa package (robocasa.utils.env_utils.convert_action) and the openpi
robocasa pipeline. This scrambles I/O when using openpi-convention checkpoints
(e.g. the JAX->PyTorch->LeRobot converted pi05 robocasa model: CloseFridge
20% -> 60% once both orders match openpi).

- convert_action: ee_pos(3)+ee_rot(3)+gripper(1)+base_motion(4)+control_mode(1)
- observation.state: ee_pos_rel(3)+ee_rot_rel(4)+base_pos(3)+base_rot(4)+gripper(2)

Matches openpi examples/robocasa/main.py + RobocasaInputs ordering.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-05 13:47:43 +02:00
pepijn223 de7ba67556 style: drop decorative === comment banners from pi052 split
Replace the === separator banners (against repo style) with plain comments.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-04 20:21:10 +02:00
pepijn223 c020c0d053 refactor(pi052): split pi05_backbone into pi_gemma + modeling_pi052
Eliminate the standalone pi052/pi05_backbone.py by distributing its contents:
- Generic dual-expert transformer machinery -> lerobot/policies/pi_gemma.py
  (sdpa_attention_forward, compute_layer_complete, PaliGemmaWithExpertModel,
  get_gemma_config; the openpi width/depth config is renamed GemmaConfig ->
  GemmaVariantConfig to avoid clashing with transformers' GemmaConfig). These
  sit next to the existing PiGemma layer code they already depend on.
- pi052-specific model + helpers -> pi052/modeling_pi052.py (PI05Pytorch,
  ActionSelectKwargs, make_att_2d_masks, pad_vector, resize_with_pad_torch,
  create_sinusoidal_pos_embedding, sample_beta, get_safe_dtype).

DEFAULT_IMAGE_SIZE is duplicated as a plain constant in pi_gemma to avoid a
pi_gemma -> pi05 import cycle. Additive to pi_gemma; pi0/pi05 unaffected.
Verified bit-exact on pepijn223/pi052_robocasa_full (embed/predict/forward
identical) and all 34 pi052 tests pass.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-04 20:18:18 +02:00
pepijn223 4cbd91a04e chore: drop one-off bench/build/train scripts from the PR
Remove development-only tooling that doesn't belong in the PR:
- examples/benchmark/* (pi052 step/kernel benchmark slurm + harness)
- examples/port_datasets/slurm_build_robocasa_composite_seen.py and
  src/lerobot/scripts/build_robocasa_composite_seen.py (composite_seen
  dataset build scripts)
- scripts/build_episode_filter.py, scripts/build_robocasa_smoke.sh,
  scripts/train_pi052_human300_exclude_unannotated.sh

None are imported by the library, tests, or entry points.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-04 20:05:25 +02:00
pepijn223 afe30630cc test(pi052): repair stale-name CE tests for fused linear CE
_fast_ce/_shifted_ce were renamed to _fast_lin_ce/_shifted_lin_ce and changed
from logits-based to Liger fused-linear-CE (hidden @ lm_head_weightᵀ). Update
the tests via thin adapters that pass an identity lm_head_weight (so the
computed logits equal the provided ones), run on CUDA (Liger is GPU-only) and
skip otherwise, and loosen the allclose tolerance to absorb GPU-vs-CPU float
noise on the tiny losses.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-04 20:03:18 +02:00
pepijn223 a594ad7969 refactor(pi052): self-contained policy; revert pi0/pi05 to upstream main
The smolvla branch had modified the shared pi0/pi05 modeling + pi05 config to
support pi052 (SDPA attention, layernorm/lm_head handling, optimizer
foreach/fused/lm_head_lr_scale, embedding scaling). Decouple pi052 instead:

- Vendor the PI0.5 backbone (PaliGemmaWithExpertModel, PI05Pytorch, helpers)
  into pi052/pi05_backbone.py (verbatim copy, no PI05Policy).
- Flatten PI052Policy to subclass PreTrainedPolicy directly (no longer
  PI05Policy); inline the needed PI05Policy methods.
- Restore optimizer_foreach/fused + get_optimizer_preset on PI052Config.
- Revert pi0, pi0_fast, pi05 modeling and configuration_pi05 to origin/main
  (byte-identical), so the shared policies carry no smolvla modifications.

Behavior verified bit-exact on pepijn223/pi052_robocasa_full: embed_language_
tokens, predict_action_chunk, and the fused flow+text+FAST training loss are
identical before/after (max_abs_diff=0). pi052 tests pass (pre-existing
stale-name collection errors unchanged).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-04 19:59:27 +02:00
pepijn223 8292548f0d fix(pi052): stop double-scaling FAST/text token embeddings
embed_language_tokens already applies Gemma's sqrt(hidden) normalizer
(GemmaTextScaledWordEmbedding, transformers >=5.4.0). pi052 multiplied FAST
action-token and autoregressive subtask-text embeddings by sqrt(emb_dim) on
top of that, double-scaling them (~2048x). Remove the manual scaling so FAST
and text tokens are single-scaled, consistent with the pi05 fix and OpenPI.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-04 18:31:41 +02:00
pepijn223 77cc35b932 fix(pi0,pi05,pi0_fast): stop double-scaling text embeddings
transformers >=5.4.0 (PR #44432) makes Gemma's embed_tokens a
GemmaTextScaledWordEmbedding that already multiplies token embeddings by
sqrt(hidden_size). The manual `* sqrt(embed_dim)` applied on top therefore
double-scaled text (~2048x instead of ~45x), breaking VLM alignment for
models trained/run on stock transformers. Remove the manual scaling and rely
on embed_tokens' internal normalizer (matches main #3603). Image features
stay raw (un-normalized), as before.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-04 18:22:34 +02:00
pepijn223 f0757fc707 fix(pi0,pi0_fast): scale text embeddings by sqrt(embed_dim) to match OpenPI
OpenPI (pi0 and pi0-FAST) multiplies language token embeddings by
sqrt(embed_dim) — the Gemma embedder normalizer — before the transformer.
LeRobot pi0/pi0_fast omitted it, leaving text tokens ~45x under-scaled
relative to the residual stream (same class of bug as the pi05 image
scaling). pi0: applied in embed_prefix's lang_embed_func. pi0_fast:
applied inside embed_language_tokens so prompt, FAST action tokens, and
autoregressive next-token embeds are all scaled consistently.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-04 18:14:27 +02:00
pepijn223 a48d4e32a1 fix(pi05): don't scale image features by sqrt(hidden_size)
lerobot/pi05_base was trained in the OpenPI/big_vision regime where image
(soft) tokens are NOT multiplied by the Gemma embedder normalizer
(sqrt(hidden_size)) — only text tokens are. Scaling image features here
over-scaled them ~45x, breaking the pretrained vision-language alignment
and yielding ~0% closed-loop success on RoboCasa across all pi05 runs.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-04 17:20:34 +02:00
Pepijn 9596e3d53f Merge remote-tracking branch 'origin/feat/smolvla-on-steerable' into feat/smolvla-on-steerable 2026-06-04 17:14:33 +02:00
Pepijn 0a6a799317 Merge feat/language-annotation-pipeline into feat/smolvla-on-steerable
Bring the authoritative annotation pipeline from the annotation branch.
The annotation surface is forced to EXACTLY match feat/language-annotation-
pipeline (the annotation branch is the source of truth for annotation
code), which also removes smolvla's stale copies:
  - deleted: steerable_pipeline/vocabulary.py, tests/annotations/test_
    vocabulary.py, prompts/module_0_vocabulary.txt, module_1_action_record
    .txt, module_3_vqa.txt, module_1_plan.txt, and the old module_* prompt
    names (now plan_*/interjections_*/vqa.txt).
  - synced: all of src/lerobot/annotations/, lerobot_annotate.py,
    examples/annotations/, tests/annotations/, datasets/language.py,
    tests/datasets/test_language.py, docs/annotation_pipeline.mdx.

Non-annotation conflicts resolved by union (keeping both branches' intent):
  - pyproject.toml: keep smolvla's pi extra (+sentencepiece) and add the
    molmoact2 extra from main.
  - policies/factory.py: keep both dataset_repo_id (pi052 FAST tokenizer)
    and dataset_meta (both are referenced); union the policy-type docstring.
  - scripts/lerobot_train.py: keep smolvla's pi052 / use_relative_actions
    processor-rebuild block.
  - uv.lock: regenerated from the merged pyproject.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-04 17:13:36 +02:00
pepijn e660a51e78 pi052(debug): drop misleading inference/parity dump from text preds
The first-token parity check re-tokenized the decoded (stripped) inference
string, so the leading-space SentencePiece variant always mismatched the
training argmax — a false "DIVERGED" alarm. Remove the autoregressive
inference print and parity comparison (and the now-dead per-sample
select_message generation), keeping only the prompt, ground-truth target,
and teacher-forced argmax accuracy.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-04 13:32:44 +00:00
Pepijn cdd94a703f annotate(config): tighten field comments to one line each
Collapse the remaining multi-line field comments / docstrings in config.py
to single lines (or two where a knob genuinely needs it), keeping the
essential rationale. Comments only — no field or behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-04 15:12:31 +02:00
Pepijn cd59c8b312 annotate: remove the action_record style/feature entirely
Drop the optional structured per-subtask action records — not a feature
we want to ship.

  * language.py: remove 'action_record' from CORE_STYLES + PERSISTENT_STYLES
    (and the matching assertion in tests/datasets/test_language.py).
  * config.py: delete ActionRecordsConfig (verb/grasp vocabularies,
    frames_per_subtask, emit_record_row) and the PlanConfig.action_records
    field.
  * plan_subtasks_memory.py: delete _extract_action_record and the
    run_episode block that emitted style='action_record' rows; drop the
    now-unused json / to_image_blocks imports.
  * remove the plan_action_record.txt prompt.
  * run_hf_job.py: drop the action_records comment.

Verified: 40 tests pass; pre-commit (ruff, mypy, bandit) clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-04 14:40:34 +02:00
Pepijn 99baae012f annotate(config): further compact field comments
Tighten the remaining multi-line comment blocks in config.py (derive_task,
frames/window, describe_first, action-record/vqa/vlm fields, video_backend,
repo ids, executor) to 1-3 lines each. Also fix a stale path typo
('examples/annotation' -> the docstring now just says HF Jobs). Comments
only — no field or behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-04 14:36:02 +02:00
Pepijn 973318ef65 annotate: dedup task_aug + row-normalization; docs module on/off table
Two behavior-preserving simplifications:
  * plan_subtasks_memory.run_episode: the task_aug 'axes' and free-form
    branches built identical deduped rows via copy-pasted seen/append
    loops. Collapse to one branch that picks the variant source, then a
    shared _task_aug_rows() helper does the dedup + row build (-~25 LOC).
  * writer: _normalize_persistent_row / _normalize_event_row shared the
    same camera-validate + struct construction. Extract _normalize_row(),
    keeping the exact key order (the parquet struct schema is inferred
    from insertion order, so timestamp must stay between style and camera).

docs: 'Which modules run' is now a table giving each module's on/off flag
(--plan.enabled / --interjections.enabled / --vqa.enabled) and what it
turns off.

Verified: 40 tests pass (incl. test_writer struct round-trip); pre-commit clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-04 14:18:36 +02:00
Pepijn 7471a6b1ed annotate: compress conftest + pyproject comments (fix stale backend note)
The pyproject annotations-extra comment still described the removed
vllm/transformers in-process backends ('vllm preferred ... transformers
fallback', '_make_vllm_client'); rewrite it for the openai-only reality
and trim it. Also condense the conftest lazy-import NOTE. Comments only.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-04 14:12:04 +02:00
Pepijn 20c7a12dd5 annotate: remove dead code, document CLI options, compact config
Dead code (defined but never referenced anywhere in src/tests/examples):
  * reader.py: keyframe_indices, episode_frame_timestamps, lookup_data_path,
    and the now-orphaned gather_data_paths + episode_offsets_per_path
    (lookup_data_path was their only caller).
  * staging.py: iter_staged_episodes.
  * writer.py: normalize_rows_for_writer.
  * config.py VlmConfig: json_mode, batch_size, tensor_parallel_size,
    gpu_memory_utilization, trust_remote_code — consumed only by the
    in-process vllm/transformers backends that were removed; the openai
    auto-serve path carries those vLLM flags via serve_command instead.
    Kept max_model_len (still used as the serve-command default).
  * config.py TaskAugAxesConfig.total property.

Docs: new 'Key options' section in annotation_pipeline.mdx — grouped
tables (dataset in/out, module toggles, --vlm.*, --plan.*, interjections
+ vqa) describing the flags users actually reach for, with defaults.

config.py: compact the verbose field comments + ActionRecordsConfig /
TaskAugAxesConfig docstrings; fix two stale 'verify' references (the
verify pass was removed — it's describe -> segment now) and the stale
'renders record back to subtask text' note (that path was removed).
vlm_client docstring no longer mentions the removed json_mode field.

Verified: tests/annotations + tests/datasets/test_language +
tests/scripts/test_lerobot_annotate (40 passed); pre-commit clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-04 14:05:46 +02:00
Pepijn dbe02f0c4f annotate(plan): condense verbose comments + docstrings
Trim the long inline comment blocks (effective_task / task_aug, action
records, plan-boundary rows, plan-update span closing, windowed +
coverage-stitch sections) and the _generate_plan / run_plan_updates
docstrings to a few lines each. No behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-04 13:52:24 +02:00
Pepijn 56cbb5f9ec annotate(example): trim run_hf_job comments to one line each
Same flags and rationale, condensed — each plan-module flag now has a
short one/two-line comment instead of a paragraph.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-04 13:48:55 +02:00
Pepijn 2af2402a0c docs(annotate): cleaner architecture diagram layout
Top-down flow (read episodes → 3 modules fan out → validator → writer →
parquet) with aligned boxes, instead of the cramped bordered version.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-04 11:59:31 +02:00
Pepijn 7bec991cdf docs(annotate): friendlier rewrite + architecture diagram; drop reproducibility section
Rewrite annotation_pipeline.mdx in plainer, easier-to-read language
(shorter sentences, active voice, a plain-text intro), add an ASCII
'How it fits together' architecture diagram, and remove the
'Reproducibility via seed and prompt hashes' section. Content/links are
preserved; only wording and structure change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-04 11:48:59 +02:00
Pepijn c6f682b3f4 annotate docs: install lerobot from main (post-merge wording)
The example already pins '@main'; update the doc step and the script
docstring from 'the branch under test' to 'lerobot (from main)' now that
the pipeline is merging to main.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-04 11:45:38 +02:00
Pepijn eba3ab3741 annotate: address review feedback — bug fixes, docs/code drift, naming, cleanup
Bugs
  * validator: don't re-raise on unknown style. The second column_for_style
    lookup (used to route persistent vs event) now sits in try/except so an
    unknown style is recorded by _check_column_routing and skipped instead
    of crashing the whole validation pass.
  * general_vqa._target_cameras: when restrict_to_default_camera is set but
    the configured camera_key isn't one the provider exposes, warn and fall
    back to all cameras instead of returning a phantom key that KeyErrors
    deep in frame decode.
  * interjections: clamp interjection timestamps to frame_timestamps[0]
    rather than a hardcoded 0.0 (datasets can start at non-zero t).

Docs / code drift
  * annotation_pipeline.mdx: drop the phantom 'vocabulary discovery / phase
    0 / --vocabulary.* / canonical_vocabulary.json' section (none of it
    exists); describe the real describe->segment + coverage-stitch flow.
    Soften the src/lerobot/tools/ + TOOL_REGISTRY reference to 'not part of
    this PR' (matches tools.mdx, which already marks the runtime layer as
    not-yet-implemented). Fix the --push_to_hub/--new_repo_id wording. Note
    the default is now a single h200. Add a 'Contributing new modules'
    section inviting module / prompt / quality contributions.
  * executor docstring: six phases, no phantom phase 0.

run_hf_job.py
  * add the Apache 2.0 license header (was flagged repeatedly).
  * default to a single GPU: flavor=h200, parallel_servers=1, num_gpus=1
    (scale to h200x4 noted in the docstring).
  * pin the install to @main instead of the feature branch (won't break
    after merge).

Naming / cleanup
  * rename dest_repo_id -> new_repo_id across config / script / example /
    test to match the LeRobot dataset edit tools.
  * rename prompt templates module_N_*.txt -> descriptive (plan_*,
    interjections_*, vqa.txt) and update every load_prompt() call.
  * remove dead _messages_to_prompt (used only by the removed in-process
    backends).
  * declare _warned_decode_fail (frames) and _warned_no_camera (vqa) as
    real init=False dataclass fields instead of getattr monkey-patches.
  * scope bandit B607 to the two ffmpeg subprocess.run sites via
    '# nosec B607' and drop it from the global skip list.

Tests
  * fix stale canned-VLM markers ('ONE realistic interruption' ->
    'compact interjection', 'Update the memory' -> 'compressed semantic
    memory') and drop the dead 'concise hierarchical PLAN' plan responders
    (plan generation is deterministic now) in run_e2e_smoke,
    test_pipeline_recipe_render, test_modules.
  * run_e2e_smoke now asserts interjection + speech rows are produced so a
    stale marker can't silently pass again.
  * drop remaining 'PR 1' / 'PR 2' references from test comments / names.

Verified: tests/annotations + tests/datasets/test_language +
tests/scripts/test_lerobot_annotate (31 passed); make-style E2E smoke
(interjections=1 speech_atoms=2); pre-commit (ruff, mypy, bandit,
prettier) clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-03 18:30:46 +02:00
Pepijn 3a24e426df language: register action_record in CORE_STYLES so STYLE_REGISTRY contains it
action_record is in PERSISTENT_STYLES but was missing from CORE_STYLES,
so STYLE_REGISTRY (= CORE_STYLES | EXTENDED_STYLES) didn't contain it and
the PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY invariant in
test_style_registry_routes_columns failed. Add it to CORE_STYLES so the
registry, the persistent-set, and column_for_style() stay consistent.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-03 16:38:06 +02:00
Pepijn b9a0187335 annotate: drop local in-process VLM backends — HF Jobs (openai) only for now
The shipped workflow is Hugging Face Jobs (examples/annotations/run_hf_
job.py): it serves the model with vLLM in the vllm/vllm-openai image and
the pipeline talks to it over the OpenAI-compatible API. The in-process
vllm / transformers local backends added surface (and the vllm
one pinned an old torch) without being part of that path, so they're
removed for now.

  * vlm_client.make_vlm_client: keep only backend='openai' (+ 'stub'
    rejected with the usual guidance). Requesting 'vllm'/'transformers'
    now raises a clear 'not supported for now — use the HF Jobs flow'
    error. Removed _make_vllm_client and _make_transformers_client.
  * config: backend docstring updated (openai-only); default model_id
    bumped to Qwen/Qwen3.6-27B to match run_hf_job.
  * docs/annotation_pipeline.mdx: remove the '## Running locally'
    section; the launcher description now says one vLLM server per GPU
    over the OpenAI API, and the 'One Qwen-VL pass' note drops the
    'vLLM/transformers fallback' wording.

Tests are unaffected (they construct StubVlmClient directly; nothing
referenced the removed backends).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-03 16:28:40 +02:00
Pepijn a18d969753 tests(annotations): fix stale canned-VLM markers + action_record style assertion
The annotation tests had never actually run in CI (collection failed on
the missing 'datasets' extra); now that they do, three stale assertions
surfaced against the evolved pipeline:

  * test_module1_plan_memory_subtask_smoke: the memory canned-responder
    marker 'Update the memory' no longer appears in module_1_memory.txt
    (now 'compressed semantic memory'), so the stub returned no memory
    row and the {subtask,plan,memory} subset check failed. Marker
    updated to match the current prompt.
  * test_module2_mid_episode_emits_paired_interjection_and_speech: the
    interjection marker 'Write ONE interjection' is now 'Write ONE
    compact interjection' in module_2_interjection.txt, so 0 interjections
    were emitted. Marker updated.
  * tests/datasets/test_language.py::test_style_registry_routes_columns:
    PERSISTENT_STYLES gained 'action_record' in this PR; add it to the
    expected set.

These are test/prompt-marker syncs — no production behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-03 16:21:17 +02:00
Pepijn 273a8fc335 deps(annotations): drop hard vllm dependency to unblock CI torch/torchcodec resolution
Fast Pytest 'dataset' tier failed collecting tests/datasets/test_video_
decoder_cache.py with 'Could not load libtorchcodec ... undefined symbol:
torch_dtype_float4_e2m1fn_x2' — a torch/torchcodec ABI mismatch.

Root cause: the annotations extra's vllm hard-pins an older torch
(via xformers/xgrammar -> torch 2.8). uv resolves a SINGLE unified lock
across all extras, so vllm capped torch to 2.8 for every tier —
including dataset, whose torchcodec 0.11.1 needs torch 2.11. The
result was torch 2.8 + torchcodec 0.11.1 installed together -> ABI break.
(main has no vllm, so it resolves torch 2.11 + torchcodec 0.11.1 cleanly.)

Fix: remove vllm from the annotations extra. It is not needed by
the shipped workflow — examples/annotations/run_hf_job.py gets vllm from
the vllm/vllm-openai image and talks to it over the OpenAI-compatible
API (--vlm.backend=openai), and vlm_client._make_vllm_client imports vllm
lazily. For the in-process --vlm.backend=vllm path, install vllm
separately (the ImportError now says so).

After the fix uv resolves torch 2.11.0 + torchcodec 0.11.1 (matching
main); uv lock --check is clean. The annotations extra still provides
datasets / transformers / openai.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-03 16:09:22 +02:00
Pepijn b9246ef61b tests(annotations): guard on the 'dataset' extra so base fast-test tier skips cleanly
Fast Pytest Tests failed at COLLECTION in the base '--extra test' tier
with 'ModuleNotFoundError: No module named datasets': tests/annotations/
conftest.py imported the fixture dataset builder (-> lerobot.datasets ->
the HF 'datasets' lib + pandas/pyarrow), which only ship under the
'dataset' extra, so the whole annotations package crashed.

Fix uses the repo's proven module-level guard pattern (see
tests/datasets/test_language.py), NOT a conftest-level importorskip —
verified empirically that pytest.importorskip raised during conftest
*import* is treated as a collection ERROR (exit 1), while module-level
importorskip is a clean SKIP.

  * conftest.py: import build_annotation_dataset LAZILY inside the
    fixtures so the conftest itself imports cleanly in every tier.
  * test_modules / test_validator / test_writer / test_pipeline_recipe_
    render: add module-level pytest.importorskip('datasets') +
    ('pandas') before the pyarrow / lerobot.* imports (# noqa: E402 to
    match the existing convention). pyarrow-importing modules place the
    guard before the pyarrow import.
  * tests/scripts/test_lerobot_annotate.py: same guard (its _push_to_hub
    path imports lerobot.datasets).

Result:
  - base / hardware / viz tiers (no dataset extra): annotation tests
    skip cleanly; the rest of the suite runs -> exit 0.
  - dataset tier: datasets present -> guards pass through -> annotation
    tests run with the stub VLM. The pipeline modules import only
    stdlib + relative + lerobot.datasets (no module-level datatrove /
    vllm / openai), so they import fine there.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-03 15:57:04 +02:00
Pepijn 870980efd6 Merge branch 'main' into feat/language-annotation-pipeline 2026-06-03 15:46:13 +02:00
Pepijn 4c86332fe3 feat(annotate): add plan toggle, drop subtask verify pass, 4xH200 job
- PlanConfig.emit_plan (default True): keep subtasks + memory but skip
  the per-boundary "plan" rows and their VLM call when False.
- Remove the subtask_verify pass entirely: pruning dropped legitimate
  subtasks and the stitch step already guarantees full-episode coverage.
  Deletes _verify_subtasks, both call sites, and the now-unused
  module_1_subtask_verify prompt.
- run_hf_job example: 4xH200 (4 vllm servers), emit_plan=false, vqa off.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-02 18:02:13 +02:00
pepijn 23419026d5 pi052: parquet-direct FAST tokenizer fit (fix v3 dataset hang)
``fit_fast_tokenizer`` previously called ``LeRobotDataset(repo_id,
episodes=[N])`` per sampled episode, which on v3-format datasets
routes through HF datasets' split lookup and raises ``ValueError:
Instruction "train" corresponds to no data!`` on every episode. On
``pepijn223/robocasa_pretrain_human300_v4`` (32 k episodes) this looped
through 13,293 skipped episodes for ~2.5 h before the NCCL watchdog
killed the run via the 2 h ALLREDUCE timeout (job 22182985).

Switch to reading the ``action`` column directly from the dataset's
``data/chunk-*/file-*.parquet`` shards (same pattern as the audit
scripts). Verified end-to-end on the 32 k-episode dataset: 1000 chunks
collected from 1000 episodes in 70.7 s.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-02 15:54:31 +00:00
Pepijn 1417fd69b2 docs(annotate): prettier format annotation_pipeline.mdx
Quality-gate fix: ruff-format/markdown prettier hook reflow of the
annotation pipeline doc. No content change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 17:41:46 +02:00
Pepijn 53c7b4c69a annotate: ruff lint + format pass
Quality-gate fixes after the main merge:
  * UP037: drop redundant quotes from PlanConfig forward-ref annotations
    (action_records / task_aug_axes) — safe under 'from __future__ import
    annotations'.
  * ruff format applied to config.py, executor.py, general_vqa.py,
    plan_subtasks_memory.py, validator.py, lerobot_annotate.py.

No behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 17:38:18 +02:00
Pepijn 3662c41b85 Merge remote-tracking branch 'origin/main' into feat/language-annotation-pipeline
# Conflicts:
#	uv.lock
2026-06-02 17:36:07 +02:00
Pepijn 518e191337 annotate: windowed subtask generation for constant temporal density
Long episodes no longer get sparse subtasks. Previously a long episode
was subsampled to max_video_frames=32 across its whole duration (~1
frame/4s for a 2-min clip). New opt-in windowing keeps a CONSTANT
frames_per_second density by splitting the episode into fixed-length
windows and running the subtask chain per window.

New PlanConfig.subtask_window_seconds (default 0.0 = off). When > 0 and
the episode is longer than one window:
  * episode is split into consecutive [w0, w1] windows of this length
  * each window's frames are sampled at frames_per_second (so a 32s
    window at 1 fps = 32 frames, filling but not exceeding the per-call
    context budget)
  * the full describe -> segment -> verify chain runs PER window, in
    window-relative time [0, L]; spans are offset back to absolute
  * all windows' spans are merged, frame-snap-deduped, and stitched into
    one contiguous whole-episode cover

Implementation:
  * _episode_video_block / _video_message / _describe_episode /
    _verify_subtasks gain an optional window=(w0,w1); when set they
    embed frames sampled in that absolute range at frames_per_second
    (video_url path skipped — it's whole-episode).
  * _clean_spans gains bounds= (override clamp range, for window-relative
    spans) and dedupe= (skip frame-snap until the merged absolute set).
  * new _generate_subtasks_windowed + _subtasks_for_window orchestrate
    the loop; _generate_subtasks branches to them when window_s > 0.

run_hf_job.py: --plan.subtask_window_seconds=32 (32s windows at 1 fps).
Cost scales with episode length (chain calls × ceil(duration/window)).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 16:26:14 +02:00
Pepijn 3236c6ee4a examples(annotate): switch run_hf_job to Qwen3.6-27B (dense VLM)
Swap the annotation VLM from Qwen3.6-35B-A3B (sparse MoE, ~3B active)
to Qwen3.6-27B (dense, 27B all-active). Per Scale's dense-captioning
study, model capacity is the #1 lever and the dominant failure is
visual grounding — both helped by ~9x more active params. Qwen3.6-27B
is a vision-language model (vision encoder, image + video), same family
so the chat template / video handling / enable_thinking=false flag are
unchanged, and at 27B dense it still fits one H200 per server, so the
two-parallel-server layout (TP=1, one per GPU) is preserved — no
throughput-layout change, just a much stronger model.

Kept: parallel_servers=2, num_gpus=2, max-model-len 32768 (the 32-frame
embedded budget is ~10k tokens, well under), gpu-mem 0.8.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 16:16:26 +02:00
Pepijn cd128cbbd5 annotate: add verb-scoped disambiguation rules to subtask prompt
Adopt the one prompt technique Scale's dense-captioning study found
reliably positive: targeted, verb-scoped, visually-grounded
disambiguation rules. Their lesson was that such a rule must fire ONLY
on the spatial situation it names (their narrow 'Stack vs Put' rule
helped; an over-broad directional 'Scoop' rule bled into other verbs
and hurt), so each rule here is phrased visually and scoped to one
confusable pair:
  * stack-vs-put (on top of an object vs on a surface)
  * insert-vs-put (fitted slot vs surface)
  * pick-up/retrieve-vs-put (decide by which way the OBJECT moves:
    gripper closes + object moves with hand = pick up; gripper opens +
    object stays = put — directly targets Scale's dominant
    direction-flip failure)
  * pour-vs-put (tilt + flow vs untilted move)

This is the highest-confidence, lowest-risk change from the Scale
findings; our pipeline already aligns with their 'avoid' list (no
temporal tokens, no overlays, no fancy sampling, no sequential context
injection, uniform sampling, describe-don't-predict framing).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 16:10:49 +02:00
Pepijn 1fb46ab300 annotate: cap embedded-frame budget to fit VLM context (fix 32k overflow)
Switching the plan module to embedded frames (use_video_url=false)
exposed a context overflow: at frames_per_second=2.0 with the old
max_video_frames=128 default, a 480x640 episode embeds ~128 frames ≈
33-39k vision tokens, over the model's 32768 context — every plan call
died with 'Input length exceeds maximum context length' (HTTP 400),
crashing the whole annotation job.

The video_url path never hit this because the server downsampled; the
embedded path sends every sampled frame, so the frame count is a hard
token budget.

Fix:
  * config default max_video_frames 128 -> 32 (~8-10k vision tokens,
    comfortable headroom for the prompt + describe/verify passes).
    Frames are still sampled UNIFORMLY across the whole episode, so
    longer episodes are subsampled, not truncated — full temporal
    coverage preserved, just coarser density.
  * run_hf_job.py: frames_per_second 2.0 -> 1.0, explicit
    --plan.max_video_frames=32, with a comment explaining the token
    budget and the 'do not raise toward 128 with embedded frames' rule.

Only the plan module embeds the full episode; VQA (1 frame/tick) and
interjections (4-frame window) were never at risk.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 16:02:25 +02:00
Pepijn 79f9a84407 annotate: make full-episode subtask coverage unconditional
Remove the subtask_full_coverage config flag. Stitching subtask spans
into a contiguous full-episode cover is now always applied in
_generate_subtasks — a sparse / gap-ridden subtask timeline is never
desirable for conditioning, so there's no reason to make it optional.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 15:36:23 +02:00
Pepijn 799d0e3bcc annotate: stitch subtasks to full-episode coverage
The verify pass prunes subtasks, which could leave the first subtask
starting after t0 or leave gaps between spans — so the subtask timeline
no longer tiled the episode and frames fell through with no active
subtask label.

New deterministic post-step (no VLM call), default on via
PlanConfig.subtask_full_coverage:
  * first subtask start pulled back to the episode's first frame t0
    (idle / approach before the first labelled action folds into it)
  * each subtask end snapped to the next subtask start (gaps closed)
  * last subtask end extended to the last frame t_last

Runs after segment + verify in _generate_subtasks. Starts other than
the first are left as the VLM/verify produced them (already frame-
snapped + distinct), so the cover is contiguous and non-overlapping.
Disable with --plan.subtask_full_coverage=false if a consumer wants
sparse subtasks.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 15:34:34 +02:00
Pepijn 1fe1463ae0 annotate: enable subtask describe->segment->verify chain by default
Flip PlanConfig.subtask_describe_first and subtask_verify defaults
False -> True. Every subtask annotation now runs the 3-call grounding
+ pruning chain by default, since the single-call path reliably
hallucinates steps from the task text. Costs 2 extra VLM calls/episode;
disable with --plan.subtask_describe_first=false / --plan.subtask_
verify=false on easy datasets where fewer calls matter more than
label fidelity.

run_hf_job.py: drop the now-redundant explicit flags, leave a note that
the chain is default-on and how to opt out.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 15:13:50 +02:00
Pepijn dcd368e1f8 annotate: multi-call subtask quality chain (describe -> segment -> verify)
The single-call 'watch video -> emit subtask JSON' pattern makes the
VLM commit to structured output before reasoning about what it saw, so
it pattern-matches the task text and hallucinates steps. Split it into
an opt-in multi-call chain that grounds first and prunes last.

New PlanConfig flags (both default False -> single-call unchanged):
  * subtask_describe_first: a grounding pass narrates ONLY what is
    visible in the video (no subtask JSON yet). That description is
    injected into the segmentation prompt via a new {observation_block}
    placeholder, so the model segments its own grounded observations
    instead of the instruction text. +1 VLM call/episode.
  * subtask_verify: after segmentation, an adversarial pass re-watches
    the video and drops any candidate subtask it cannot see. Can only
    PRUNE (never add/rewrite/move) and fails open (keeps un-verified
    spans if the call returns nothing). +1 VLM call/episode.

Implementation:
  * _generate_subtasks now orchestrates describe -> segment -> verify.
  * Factored span cleaning into _clean_spans (shared by segment + verify
    outputs); added _describe_episode and _verify_subtasks helpers.
  * New prompts module_1_subtask_describe.txt (returns {description})
    and module_1_subtask_verify.txt (returns pruned {subtasks}).
  * module_1_subtasks.txt gains a {observation_block} slot at the top.

run_hf_job.py enables both for the RoboCasa run (3 VLM calls/episode
for subtasks). Combined with single-camera grounding + the embedded-
frame path, this is the high-quality configuration.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 15:12:46 +02:00
Pepijn ba5d4c5cd8 annotate: kill subtask hallucination + single-camera grounding
Two fixes for 'subtasks describe actions not in the video' plus a way
to focus the whole pipeline on one camera.

ANTI-HALLUCINATION
  1. _episode_video_block: when use_video_url is set but clip extraction
     fails, FALL BACK to embedded frames instead of returning an empty
     block. An empty block left the VLM with zero visual grounding, so
     it invented subtasks from the task text alone — the likely root
     cause of hallucinated steps. Now logs a warning and embeds frames.
  2. module_1_subtasks.txt gains a GROUNDING preamble (overrides all
     other rules): label only motion visible in specific frames; never
     invent/anticipate/pad; max_steps is a CEILING not a target; atomic
     demos may be exactly ONE subtask; the VIDEO is ground truth, not
     the instruction text.

SINGLE-CAMERA GROUNDING
  * New VqaConfig.restrict_to_default_camera (default False). When True,
    the VQA module grounds on only the --vlm.camera_key stream instead
    of iterating every camera — matching the plan / interjection
    modules, which already use that single camera. Now the whole
    pipeline can focus on one view (e.g. observation.images.base).

run_hf_job.py updated:
  * use_video_url=false + frames_per_second=2.0 — embed frames directly
    (most reliable; no silent text-only failure mode) with dense
    grounding.
  * vqa.restrict_to_default_camera=true — VQA on the single camera too.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 15:08:25 +02:00
Pepijn 7454b4c993 annotate: remove action-record subtask-text replacement entirely
Drops the replace_subtask_text option and the
_render_action_record_to_subtask_text renderer. Action records are now
strictly additive: when action_records.enabled=True the module emits
style='action_record' rows (the typed {verb,object,arm,grasp,dest,
mistake} schema) and NEVER rewrites the subtask text the policy
conditions on.

The render-back-to-text path was the source of corrupted subtasks
(navigation tasks produced 'move stove to stove', manipulation tasks
got spurious 'with left arm using pinch grip' suffixes). Reconstructing
natural-language subtasks from hallucinated structured fields is
inherently fragile, so the capability is removed rather than guarded.

Removed:
  * ActionRecordsConfig.replace_subtask_text field
  * PlanSubtasksMemoryModule._render_action_record_to_subtask_text
  * the span['text'] = canonical_text overwrite in run_episode

Updated docstrings + run_hf_job.py comment accordingly. emit_record_row
(default True) is now the feature's only output.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 14:42:36 +02:00
Pepijn c5042a6850 fix(annotate): stop action records + augmentation from corrupting RoboCasa labels
Three compounding bugs made RoboCasa annotation produce off-task
subtasks ('move stove to stove with left arm') and drifting
augmentations ('wander around the kitchen' for 'Navigate to the stove').

1. action_records.replace_subtask_text now defaults False.
   Overwriting the VLM's subtask text with a reconstruction of
   hallucinated {verb,object,arm,grasp,dest} fields is high-risk:
   navigation / non-manipulation tasks don't fit the schema and render
   to nonsense. Records are now additive by default (emit_record_row),
   never silently replacing subtask text. Flip replace_subtask_text on
   only for manipulation datasets verified to render cleanly.

2. _render_action_record_to_subtask_text drops a degenerate
   destination that just echoes the object (verb=move object=stove
   destination=stove -> 'move stove' instead of 'move stove to stove').
   Also routes 'navigate' through the 'to <dest>' preposition family.

3. module_1_task_aug_axes.txt hardened: variants MUST preserve the
   goal/destination. Explicitly forbids 'Navigate to the stove' ->
   'wander around the kitchen'. Only wording / arm / orientation /
   grasp may vary; verb meaning, object, and destination are fixed.

examples/annotations/run_hf_job.py — corrected for RoboCasa:
  * derive_task_from_video=off (was =always). The dataset task string
    is authoritative and is what eval conditions on; =always threw it
    away, re-derived a hallucinated task from the video, and poisoned
    every downstream subtask/plan row. THIS was the dominant cause.
  * n_task_rephrasings=0 + task_aug_axes left off — RoboCasa eval uses
    exact task strings, so augmentation is unused/harmful.
  * action_records left off — manipulation schema doesn't fit atomic /
    navigation tasks.
  * plan_max_steps=6 to keep atomic-task decomposition tight.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 14:34:48 +02:00
pepijn223 ff1d58a46f pi052: suppress FAST action tokens in select_message text generation
The FAST action tokenizer maps action codes to the top of the PaliGemma
vocab (id = vocab_size-1-fast_skip_tokens-t). The lower part of that band
sits just below the reserved <loc> block, so it escaped the existing
suppress_loc_tokens mask and leaked into generated subtask/VQA/memory text
as high-codepoint gibberish. Mask the FAST band on every select_message
call so the high-level head emits clean language.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-02 13:07:02 +02:00
Pepijn 98a519e7f2 fix(annotate): default frame provider to video keys, not image keys
VideoFrameProvider derived its default camera and camera list from
meta.camera_keys, which mixes image- and video-stored cameras. The
clip/decode paths read videos/<key>/from_timestamp, which only exists
for video keys, so an image-stored camera sorted first (e.g.
observation.images.wrist) crashed the plan phase with a KeyError.

Restrict the list and default to meta.video_keys. Add a regression test
and point the example job at the dataset's actual video camera. Skip
bandit B607 (ffmpeg/git are intentionally resolved via PATH).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-02 12:09:55 +02:00
Pepijn 5dbf0fac5f annotations(steerable): remove Phase 0 canonical vocabulary discovery
Drops the optional Phase 0 vocabulary-discovery feature entirely.
With the new structured action records (Phase 1a + 1b) providing
cross-episode consistency via the deterministic template renderer,
the older vocabulary-constraint path is redundant and adds a second
constraint mechanism that wasn't well-validated in practice.

Removed:
  * src/lerobot/annotations/steerable_pipeline/vocabulary.py
    (Vocabulary dataclass + VocabularyDiscoveryModule + load_/
    save_vocabulary helpers; canonical_vocabulary.json on-disk format)
  * src/lerobot/annotations/steerable_pipeline/prompts/module_0_vocabulary.txt
    (Phase 0 VLM prompt)
  * tests/annotations/test_vocabulary.py

Pruned wiring across:
  * config.py: VocabularyConfig dataclass + AnnotationPipelineConfig.
    vocabulary field
  * executor.py: vocabulary attribute on Executor + _run_vocabulary_
    phase method + Phase 0 phases.append call in run()
  * modules/plan_subtasks_memory.py: Vocabulary import + vocabulary
    attribute + _subtask_vocabulary_block / _memory_vocabulary_block
    helpers + _canonicalize_subtask / _normalize / _invalid_subtasks
    / _build_subtask_retry_message methods + vocabulary-gated retry
    path in _generate_subtasks + empty-episode warning + _NORMALIZE_
    STRIP_TOKENS constant
  * prompts/module_1_subtasks.txt: {vocabulary_block} placeholder
  * prompts/module_1_memory.txt: {vocabulary_block} placeholder
  * __init__.py: Vocabulary / VocabularyDiscoveryModule / load_
    vocabulary / save_vocabulary / vocabulary_path / VOCABULARY_
    FILENAME re-exports
  * scripts/lerobot_annotate.py: VocabularyDiscoveryModule import +
    instantiation + executor argument
  * examples/annotations/run_hf_job.py: --vocabulary.enabled=false
    flag + docstring references + inline phase-0 comment

The original free-form rephrasings path stays (PlanConfig.
n_task_rephrasings still works when task_aug_axes.enabled=False).
Action records remain the preferred mechanism for cross-episode
subtask consistency.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 11:48:27 +02:00
Pepijn 2bfaf44db2 annotations(steerable): structured action records + 5-axis task augmentation
EgoMimic-inspired additions to the plan module, both opt-in for back-compat.

1. PHASE 1a + 1b: per-subtask structured action records
   * cfg.action_records.enabled=True triggers, after Phase 1 subtask-span
     generation, one extra VLM call per subtask to extract a typed record:
       {verb, object, arm, grasp_type, destination, mistake}
   * A deterministic Python template (_render_action_record_to_subtask_text)
     renders the record back to canonical subtask text. When replace_subtask_
     text=True (default), this REPLACES the VLM's free-form text — eliminates
     cross-episode phrasing drift.
   * When emit_record_row=True (default), the structured record is also
     emitted as a row with style='action_record' (added to PERSISTENT_STYLES)
     so downstream training can consume the typed schema directly.
   * Verb + grasp vocabularies are configurable. Out-of-vocab values are
     rejected at extraction time.

2. STRUCTURED 5-AXIS TASK AUGMENTATION
   * cfg.task_aug_axes.enabled=True replaces the free-form n_task_rephrasings
     path with a structured prompt producing variants along 5 named axes:
       synonym_paraphrase (3)
       omit_arm           (3)
       omit_orientation   (2)
       omit_grasp_method  (2)
       combined_omissions (2)
     Total ~12 variants. Axes with nothing to omit emit fewer entries.
   * Each variant is emitted as a task_aug row at t=0 (existing style).

Inspired by https://github.com/GaTech-RL2/EgoVerse/tree/main/egomimic/scripts/language_process
— they pay Scale AI annotators to fill a structured form and then generate
language via a deterministic prompt. We get the same hallucination-reducing
structure via one extra VLM call per subtask.

Files:
  src/lerobot/datasets/language.py
  src/lerobot/annotations/steerable_pipeline/config.py
  src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py
  src/lerobot/annotations/steerable_pipeline/prompts/module_1_action_record.txt
  src/lerobot/annotations/steerable_pipeline/prompts/module_1_task_aug_axes.txt

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 11:35:35 +02:00
Pepijn d04ea0ea8a annotations(steerable): structured action records + 5-axis task augmentation
EgoMimic-inspired additions to the plan module, both opt-in for back-compat.

1. PHASE 1a + 1b: per-subtask structured action records
   * cfg.action_records.enabled=True triggers, after Phase 1 subtask-span
     generation, one extra VLM call per subtask to extract a typed record:
       {verb, object, arm, grasp_type, destination, mistake}
   * A deterministic Python template (_render_action_record_to_subtask_text)
     renders the record back to canonical subtask text. When replace_subtask_
     text=True (default), this REPLACES the VLM's free-form text — eliminates
     cross-episode phrasing drift.
   * When emit_record_row=True (default), the structured record is also
     emitted as a row with style='action_record' (added to PERSISTENT_STYLES)
     so downstream training can consume the typed schema directly.
   * Verb + grasp vocabularies are configurable. Out-of-vocab values are
     rejected at extraction time.

2. STRUCTURED 5-AXIS TASK AUGMENTATION
   * cfg.task_aug_axes.enabled=True replaces the free-form n_task_rephrasings
     path with a structured prompt producing variants along 5 named axes:
       synonym_paraphrase (3)
       omit_arm           (3)
       omit_orientation   (2)
       omit_grasp_method  (2)
       combined_omissions (2)
     Total ~12 variants. Axes with nothing to omit emit fewer entries.
   * Each variant is emitted as a task_aug row at t=0 (existing style).

Inspired by https://github.com/GaTech-RL2/EgoVerse/tree/main/egomimic/scripts/language_process
— they pay Scale AI annotators to fill a structured form and then generate
language via a deterministic prompt. We get the same hallucination-reducing
structure via one extra VLM call per subtask.

Files:
  src/lerobot/datasets/language.py
  src/lerobot/annotations/steerable_pipeline/config.py
  src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py
  src/lerobot/annotations/steerable_pipeline/prompts/module_1_action_record.txt
  src/lerobot/annotations/steerable_pipeline/prompts/module_1_task_aug_axes.txt

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-06-02 11:31:42 +02:00
pepijn223 bb2c09965b pi052: hierarchical select_action + RoboCasa eval video overlay
- modeling_pi052: per-env low-level subtask generation in select_action so
  hierarchical inference is correct for eval.batch_size > 1
- render_messages_processor: always emit a fallback low-level prompt so
  observation.language.tokens are produced when recipe annotations are absent
- lerobot_eval: overlay high-level task + predicted subtask onto recorded
  rollout videos (render path only; does not affect policy observations)

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-01 14:35:13 +02:00
pepijn 1f1541243a pi052: make `lerobot-eval` work on saved checkpoints
pi052's preprocessor pipelines don't roundtrip through the saved
``policy_preprocessor.json``: ``RenderMessagesStep`` holds a
``TrainingRecipe`` Python object (not JSON-serializable, saved as
``{}``) and ``ActionTokenizerProcessorStep`` saves the fitted FAST
tokenizer's host-only ``~/.cache/lerobot/fast_tokenizers/...`` path.
``PolicyProcessorPipeline.from_pretrained`` then dies with
``RenderMessagesStep.__init__() missing 1 required positional
argument: 'recipe'`` (job 22164494).

The pi052 training path was workable because the recipe-aware steps
were built directly; the runtime path
(``lerobot.scripts.lerobot_pi052_runtime``) sidesteps the loader by
passing ``pretrained_path=None`` to ``make_pre_post_processors`` and
building fresh from ``config.recipe_path``. The standard
``lerobot-eval`` entry point had no such escape hatch.

Two surgical fixes:

* ``factory.make_pre_post_processors``: when ``policy_cfg.type ==
  "pi052"`` AND ``pretrained_path`` is set, bypass the generic
  ``PolicyProcessorPipeline.from_pretrained`` call. Build the
  pipelines fresh via ``make_pi052_pre_post_processors`` (same
  bootstrap the runtime uses) and transplant the saved stateful
  blobs from each step's ``state_file`` reference in the saved JSON
  (today: NormalizerProcessorStep + UnnormalizerProcessorStep
  quantile stats). Pairing is by ``registry_name`` AND position so
  a benign reorder logs a warning instead of silently mis-loading.

* ``PI052Config.use_hf_kernels``: re-add as a deprecated no-op
  field. The flag was removed in d70c8104 (Liger kernels became
  unconditional), but checkpoints saved before that commit
  serialize ``use_hf_kernels: true`` into ``config.json``. Without
  this field draccus rejects the load with ``DecodingError: The
  fields use_hf_kernels are not valid for PI052Config`` (job
  22164492). Mark for removal in a future major bump.

Together these let an external ``lerobot-eval --policy.path=<pi052
checkpoint>`` invocation evaluate the model against any env.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-27 09:14:34 +00:00
pepijn d70c810416 pi052: drop `use_hf_kernels` flag — always patch Liger kernels
The flag gated a process-global, idempotent Liger patch that swaps
in fused Triton rope / geglu / layer_norm kernels (~4.5 % step time
on H100, bench job 22161421). Since liger-kernel is now a hard
dependency of the loss path (``_shifted_lin_ce`` / ``_fast_lin_ce``
in ``modeling_pi052``), gating the same dep behind an opt-in flag
was redundant — every pi052 run pulls the wheel in either way.

* ``PI052Policy.__init__`` calls ``_enable_hf_kernels()``
  unconditionally; the function still degrades gracefully if the
  wheel happens to be missing (logs a warning, returns).
* Drop ``PI052Config.use_hf_kernels``; the bench numbers and the
  ``fused_linear_cross_entropy`` pointer to ``_shifted_lin_ce`` /
  ``_fast_lin_ce`` are kept as comments next to the docstring.
* Update the warning + ``_shifted_lin_ce`` lazy-import comment to
  drop stale ``use_hf_kernels`` / ``reduce-overhead`` references.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-26 11:47:49 +00:00
pepijn 4c3ddb1ff5 pi052: wire Liger fused linear CE + DDP-safe FAST tokenizer fit
* Replace ``_shifted_ce`` / ``_fast_ce`` with Liger's
  ``fused_linear_cross_entropy``: the ``(B, T, 257k)`` logits tensor
  is no longer materialised — the kernel chunks over the ``(B*T)``
  axis and computes matmul + softmax + CE in fused Triton blocks.
  ~30 % step speedup and ~12 GB of activation memory freed on the
  dual-CE pi052 recipe. All four call sites in
  ``_compute_all_losses_fused`` and ``_compute_text_and_fast_loss``
  updated; the ``.any().item()`` CPU sync is dropped so the loss
  path stays CUDA-graph-capturable.

* DDP-safe FAST tokenizer fit. The cache-hit sentinel previously
  looked for ``preprocessor_config.json`` but
  ``ProcessorMixin.save_pretrained`` writes ``processor_config.json``
  — every rank always cache-missed and re-fit, racing on writes and
  occasionally producing a stale ``.pyc`` that crashed
  ``AutoProcessor.from_pretrained`` with ``AttributeError:
  UniversalActionProcessor``. Fix the sentinel; gate the fit on the
  (local) main process; non-leader ranks poll the cache until the
  leader is done. Caught by job 22162549.

* New recipe ``subtask_mem_vqa_robocasa.yaml`` — subtask + memory +
  per-camera VQA over the three robocasa camera keys produced by the
  port pipeline (``robot0_agentview_left/right``, ``robot0_eye_in_hand``).
  The previously-shipped ``subtask_mem_vqa_speech.yaml`` references
  ``observation.images.front`` / ``wrist`` which don't exist in
  robocasa, so VQA never rendered.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-26 11:18:16 +00:00
pepijn 8615f3f613 annotate(vqa): tighten bbox + keypoint quality bar
Low-confidence VLM detections were producing many overlapping, loose
boxes per frame (oven + toaster oven + counter + drawer + ...) and
coarse keypoints, hurting downstream policy grounding. Two surgical
fixes:

- module_3_vqa prompt: cap bbox at most 3 high-confidence detections
  (prefer 1 tight box), require specific labels and ≤10% padding,
  allow empty detections list when nothing meets the bar; keypoint
  must be a single pixel-precise feature (handle / button / gripper
  tip) rather than a coarse "somewhere on object" point.
- run_hf_job: lower vlm.temperature 0.7 → 0.2. Bbox + keypoint are
  coordinate-regression tasks where sampling noise directly degrades
  localization; question phrasing still varies enough at 0.2.

No new config knobs — the count cap lives in the prompt since "top-N
by confidence" is best picked by the VLM itself. Validator already
accepts empty detections.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-26 08:31:37 +00:00
pepijn 1e7c0d6aa1 annotate(plan): force composite-action subtasks; ban ultra-fine splits
Tighten ``module_1_subtasks.txt`` so the VLM emits one composite
atomic action per subtask instead of decomposing every pick into
``move to X`` / ``grasp X`` / ``lift X``:

- Lock the verb vocabulary to the composite set the low-level
  policy actually learns end-to-end: ``pick up`` (approach + grasp +
  lift), ``put``/``place`` (transport + release), ``push``, ``pull``,
  ``turn``, ``press``, ``open``, ``close``, ``pour``, ``insert``.
  ``go to`` is allowed only as a pure relocation between phases.
- Add an explicit ``Forbidden ultra-fine splits`` block enumerating
  the patterns the VLM was tempted to emit (``move to X``,
  ``reach for X``, ``grasp X``, ``lift X``, ``release X``) and
  instructing it to fold each into its parent composite.
- Rewrite the Good/Bad examples to match the composite contract;
  the previous ``"move to blue cube" / "grasp blue cube" / "lift
  blue cube"`` Good list was actively encouraging the over-
  segmentation pattern this prompt is supposed to prevent.
- Tighten the duration rule: candidates shorter than
  ``min_subtask_seconds`` must be merged into a neighbour rather
  than emitted. Pairs with bumping the runtime floor to 3 s so
  composites have room to land.

Pure prompt change — no code or schema change. Existing canonical-
vocabulary retry path is unaffected (the new verb whitelist lives
in prose, not in the validator).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-26 05:14:30 +00:00
pepijn 2686450d68 annotate(plan): force composite-action subtasks; tune run_hf_job for robocasa_smoke
Subtask prompt (``module_1_subtasks.txt``):
- Lock the verb vocabulary to composite atomic actions (``pick up``,
  ``put``/``place``, ``push``/``pull``, ``turn``, ``press``, ``open``/
  ``close``, ``pour``, ``insert``, ``go to``).
- Add an explicit ``Forbidden ultra-fine splits`` block instructing
  the VLM to fold ``move to X`` / ``reach for X`` / ``grasp X`` /
  ``lift X`` / ``release X`` into the parent composite. Previous
  examples actively encouraged the over-segmentation pattern.
- Rewrite the Good/Bad examples around the composite contract.

Job config (``examples/annotations/run_hf_job.py``):
- Point at ``pepijn223/robocasa_smoke_2atomic_v3`` on ``h200x4``.
- ``--vlm.camera_key=robot0_agentview_left`` (real key for the
  dataset; the prior ``observation.images.wrist`` did not exist
  and would have silenced the VQA module).
- ``--vlm.serve_command`` ``--max-model-len 131072`` (4x): keeps
  90 s @ 1 Hz episode video blocks under context even at full
  Qwen vision resolution. On 1x H200 (144 GB) the 35B-FP8 model
  has plenty of room for the bigger KV cache.
- ``--vocabulary.enabled=false`` — heterogeneous dataset, no
  benefit from a single canonical vocabulary.
- ``--plan.derive_task_from_video=off``, ``--plan.n_task_rephrasings=0``
  — reuse the dataset's own ``episode_task`` strings as-is.
- ``--plan.min_subtask_seconds=3.0``, ``--plan.plan_max_steps=6`` —
  give the new composite-action rules room to land (1.5 s floor
  was too small to host a full grasp-or-place composite).
- ``--vqa.vqa_emission_hz=3.0`` — denser VQA grounding.
- Timeout 24h, episode_parallelism=64, client_concurrency=256 to
  scale to the 25k-trajectory regime when the same recipe is
  pointed at a larger dataset.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-26 05:14:23 +00:00
pepijn 920c6ef5a2 docs(annotate): disable phase-0 vocabulary discovery by default in run_hf_job
Heterogeneous datasets (different tasks/scenes across episodes) don't
share a single small subtask + memory vocabulary, so the canonical
vocabulary phase narrowed every episode to the wrong target distribution.
Flip the example to free-form generation by default and document the
``--vocabulary.enabled=true`` switch for homogeneous datasets where the
canonical vocabulary still helps the downstream policy.

No pipeline-code changes: ``VocabularyConfig.enabled`` already gates
phase 0 (see ``executor.py:_run_vocabulary_phase`` and
``VocabularyConfig`` docstring) and falls back to free-form generation.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-26 04:42:10 +00:00
pepijn 4913356564 pi052: SDPA attention port + selective AC + bench harness
Replaces the per-layer ``modeling_gemma.eager_attention_forward`` call
with ``torch.nn.functional.scaled_dot_product_attention`` in
``compute_layer_complete`` (pi05) and ``_compute_layer_ki`` (pi052).
PyTorch SDPA picks the memory-efficient kernel for the
block-bidirectional 4D additive mask the dual-expert model uses (FA2 /
FA3 reject it because they only accept causal / sliding-window / varlen
patterns). The shared ``sdpa_attention_forward`` helper mirrors the
eager signature so the call sites are unchanged.

Selective AC: removes the redundant outer ``_apply_checkpoint(forward_func, ...)``
wrap in ``PI05Pytorch.forward``. Per-layer checkpointing inside
``PaliGemmaWithExpertModel.forward`` already handles activation
recompute; the outer wrap was double-recomputing the whole backbone.
+14% steps/sec on its own (job 22161405 vs 22161398, 1xH100).

groot: drop ``@strict`` on ``GR00TN15Config`` — newer ``huggingface_hub``
rejects ``@strict`` on non-dataclass ``PretrainedConfig`` subclasses,
which was blocking imports of any sibling policy through
``lerobot.policies.factory``.

New ``examples/benchmark/bench_pi052_step.py`` (+ slurm sweeps v1..v8)
times PI052Policy.forward+backward (optionally with AdamW) on
synthetic inputs. Headline numbers on 1xH100 with KI=True, GC=True,
L=512, 4.14 B trainable params, AdamW state in bf16:

  pre-SDPA eager BS=8                 610ms   19.5 GiB  ->  13.1 samples/s
  sdpa  BS=8  + compile=default       413ms   19.5 GiB  ->  19.3 samples/s
  sdpa  BS=16 + compile=default       715ms   37.3 GiB  ->  22.4 samples/s
  sdpa  BS=32 + compile=default      1325ms   44.8 GiB  ->  24.2 samples/s
  sdpa  BS=40 + compile=default      1665ms   48.6 GiB  ->  24.0 samples/s

Parity tests in ``tests/policies/pi052/test_pi052_sdpa_attention.py``
cover fp32 / bf16 / GQA / MHA forward + backward — output and grads
match the eager path within bf16 tolerance.

Also ships ``examples/benchmark/fsdp_pi052.yaml`` (FSDP2 accelerate
config wrapping GemmaDecoderLayer + SiglipEncoderLayer) for the
follow-up multi-GPU memory sharding work.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-25 21:59:20 +00:00
pepijn 673cc6b0fe pi052: opt-in Liger fused kernels (rope + geglu + layer_norm)
Adds ``PI052Config.use_hf_kernels`` (default off). When enabled,
``PI052Policy.__init__`` calls ``apply_liger_kernel_to_paligemma``
before the backbone is built so PaliGemma / Gemma / Siglip layers
pick up Liger's fused Triton forwards.

Measured at BS=16 / L=512 / H100 80GB with KI+GC on (bench job
22161421, see ``examples/benchmark/bench_pi052_kernels.slurm``):

  rope only        →  -2.5% step time
  geglu only       →  -2.2% step time
  layer_norm only  →  -1.1% step time
  all three        →  -4.5% step time, peak_mem unchanged

``cross_entropy`` / ``fused_linear_cross_entropy`` are deliberately
skipped — pi052 calls ``F.cross_entropy`` directly and bypasses
``PaliGemmaForConditionalGeneration.forward``, so neither patch
fires without invasive model-code changes (left for a follow-up).
``rms_norm`` measured as noise on this workload (GC dominates),
so it stays off to keep the patch surface minimal.

Requires ``pip install liger-kernel``; falls back to a warning if
missing so the default path is unaffected.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-25 20:50:07 +00:00
Pepijn 2ed6519a93 ema: enable by default (matches openpi JAX behavior)
Flip EMAConfig.enable default from False -> True. Every training run
now maintains an EMA shadow of the policy and uses it for eval + W&B
example dumps. Disable per-run with --ema.enable=false for short or
memory-constrained training.

Rationale:
  * openpi (JAX, official) ships EMA on for every shipped config,
    decay=0.99 by default and 0.999 for pi05_libero. The openpi
    PyTorch port explicitly lists EMA as unsupported, a gap LeRobot
    main inherited. Flipping the default closes that gap for every
    LeRobot policy that ships through lerobot-train.
  * EMA is established best practice for diffusion / flow-matching
    policies (Diffusion Policy §V.D; standard in DDPM/EDM/Stable
    Diffusion training recipes). For autoregressive policies the
    extra cost is real but the safety net (smoother eval, better
    final checkpoint) doesn't hurt.

Trade-offs to be aware of:
  * Memory: 1x model params in fp32 shadow (~13 GB for pi052's
    3.3B params; <500 MB for ACT/Diffusion-Policy class). Memory-
    constrained users on consumer GPUs may need --ema.enable=false.
  * Checkpoint disk: extra .pt file in training_state/, size ~=
    pretrained_model/model.safetensors. Over a 100k-step run with
    save_freq=20000 that's 5x the model size in extra disk.
  * Eval scores will now reflect EMA model instead of live model -
    expected to be 1-3% higher on closed-loop tasks per the
    diffusion-policy literature; might surprise users who memorize
    their last run's numbers.

Opt out:
  --ema.enable=false           # disable entirely
  --ema.use_for_eval=false     # keep EMA but eval reflects live
  --ema.use_for_wandb_examples=false   # keep EMA but W&B reflects live

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 21:58:46 +02:00
Pepijn 72ea531017 train: switch EMA from custom ModelEMA to ema-pytorch
Replace the 250-line src/lerobot/utils/ema.py with a direct dependency
on ema-pytorch (lucidrains' canonical PyTorch EMA library). Same
semantics, decay=0.999 default unchanged, but offloads the maintenance
burden to a maintained library used by every diffusion repo.

Why ema-pytorch:
  * Standard PyTorch EMA library; battle-tested across diffusion +
    speech + image-gen codebases.
  * Tiny pure-python dep (no compiled code).
  * Cleaner consumer-side API: ema.ema_model is a full nn.Module
    clone of the policy, so eval / wandb just pass it through instead
    of context-managed swap/restore on the live model.

What changed mechanically:
  * pyproject.toml: add 'ema-pytorch>=0.7.7,<1.0.0' to core deps.
  * deleted src/lerobot/utils/ema.py (the custom ModelEMA).
  * scripts/lerobot_train.py:
      - import EMA from ema_pytorch
      - instantiate with beta=cfg.ema.decay,
        update_after_step=cfg.ema.warmup_steps, update_every=1,
        include_online_model=False (accelerator owns live model
        lifecycle; double-registration would double-count params).
      - ema.update() (no args) — library tracks the online model
        internally.
      - Eval block: pass eval_target_policy = ema.ema_model (when
        cfg.ema.use_for_eval) instead of swap context manager.
      - W&B examples: same pattern.
      - Save: torch.save(ema.state_dict(), .../ema_state.pt) instead
        of custom safetensors writer. .pt format is consistent with
        the rest of training_state which already mixes safetensors +
        json + (now) pt.
      - Resume: ema.load_state_dict(torch.load(.../ema_state.pt)).
      - WandB observability: ema/step (count of ema.update calls),
        ema/initted (bool from library), ema/beta (constant from
        cfg).
  * configs/default.py: EMAConfig.decay stays 0.999 (matches
    openpi's pi05_libero); docstring updated to reflect ema-pytrch
    semantics for warmup_steps (now maps to update_after_step — a hard
    skip, not a smooth decay ramp).

Behavior preserved:
  * Defaults: enable=False, decay=0.999, warmup_steps=0,
    use_for_eval=True, use_for_wandb_examples=True.
  * Same CLI: --ema.enable=true, --ema.decay=X, etc.
  * Same checkpoint layout (training_state/ema_state.pt next to
    optimizer_state.safetensors etc.); resumes silently if present.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 21:51:23 +02:00
Pepijn 56a934ec55 train: EMA of policy parameters (opt-in via --ema.enable=true)
Adds Exponential Moving Average of trainable policy parameters with
warmup, eval-time swap, checkpoint save/resume, and wandb observability.

For diffusion / flow-matching policies (pi052's flow expert exactly
qualifies), averaging late-training parameter oscillations yields a
smoother model that generalises substantially better at inference —
~1–3% absolute success-rate improvement on closed-loop tasks per the
diffusion-policy lit (Chi et al. 2023 §V.D; standard in DDPM/EDM).

New module: src/lerobot/utils/ema.py
  ModelEMA class with:
    * fp32 shadow of every requires_grad parameter
    * decay warmup: min(decay, (1+n)/(10+n)) for first warmup_steps updates
    * update(model) -> effective_decay (for logging)
    * apply_to(model) context manager: temp-swap weights, restore on exit
    * copy_to(model): permanent overwrite
    * save() / load_from_file(): safetensors + JSON sidecar for metadata
    * state_dict() / load_state_dict() for in-process round-tripping

New config: src/lerobot/configs/default.py EMAConfig + wired into
TrainPipelineConfig as 'ema: EMAConfig'.
  Fields:
    enable: bool = False         (off by default, back-compat)
    decay: float = 0.999         (standard; 0.75 for fast Diffusion-Policy)
    warmup_steps: int = 0        (no warmup by default)
    use_for_eval: bool = True    (eval swaps in EMA weights)
    use_for_wandb_examples: bool = True
                                 (W&B training-examples table uses EMA
                                  for predicted-action columns -> matches
                                  what eval / deployment would see)

Training loop integration (src/lerobot/scripts/lerobot_train.py):
  1. After accelerator.prepare + policy.train(), instantiate ModelEMA
     on the main process if cfg.ema.enable. Resume from
     checkpoint_path/training_state/ema_state.safetensors if present.
  2. After each update_policy() call, ema.update(unwrap_model(policy))
     returns the effective decay (logged to wandb during warmup).
  3. The save_checkpoint() block also ema.save(...) the shadow next to
     the existing optimizer/scheduler/rng training state. Resume picks
     it up automatically in (1).
  4. The eval block (cfg.env && is_eval_step) wraps eval_policy_all in
     ema.apply_to() when use_for_eval=True. Live weights restored
     byte-for-byte on context exit.
  5. The W&B training-example dump wraps log_training_examples in
     ema.apply_to() when use_for_wandb_examples=True so the predicted-
     action columns match the eval/deployment behavior.
  6. Two new wandb scalars: ema/effective_decay, ema/num_updates.

Cost:
  Memory: 1x model params in fp32 (~13 GB for pi052's 3.3B params).
          Lives only on main-process GPU. CPU offload available via
          ModelEMA(device='cpu') if needed.
  Compute: one elementwise update per step (~1% of step time).
  Eval: 2x checkpoint files in training_state/ (live optimizer state
        + ema shadow). Negligible relative to model.safetensors.

Usage:
  lerobot-train ... --ema.enable=true
  lerobot-train ... --ema.enable=true --ema.decay=0.9999  # very slow EMA
  lerobot-train ... --ema.enable=true --ema.warmup_steps=1000

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 21:27:14 +02:00
Pepijn 738e317caa pi052: PaLM-style z-loss on text CE (default weight 1e-4)
Penalise the log-partition function z = log Σ exp(logits) drifting away
from zero on text-CE supervised positions. Without it, large-vocab
models (PaliGemma's 257k vocab) can let logsumexp grow unboundedly
while CE stays low — a uniform additive logit bias cancels in softmax
but pushes the partition function out of bounds, causing numerical
instability and generation drift.

PaLM appendix B / Chinchilla report z-loss is essential for stable
large-vocab CE. It is especially valuable for pi052 because the recent
default lm_head_lr_scale=5.0 amplifies head-drift risk: the 5x boost
keeps the head pinned to fine-tuning targets, and z-loss caps the
partition function so the head can't just bias all logits high uniformly.

Implementation:
  * _shifted_ce(logits, labels, z_loss_weight=0.0) gains the new arg
    with default 0.0 (back-compat for any other caller).
  * Both call sites in PI052Policy.forward read self.config.text_ce_
    z_loss_weight and pass it through.
  * PI052Config.text_ce_z_loss_weight defaults to 1e-4 (commonly cited
    PaLM value); set to 0 to disable.

Cheap to compute: one extra logsumexp shares the softmax kernel that
F.cross_entropy already runs. No memory overhead beyond a (B*T,) tensor.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 21:08:56 +02:00
Pepijn 8ba3b187a1 pi052: bump lm_head_lr_scale default to 5.0 (keep base LR at 2.5e-5)
The base optimizer LR (2.5e-5, cosine to 2.5e-6, 1k warmup, AdamW
(0.9, 0.95), wd 0.01, grad_clip 1.0) is the openpi/π0.5 setting used
for the RoboCasa leaderboard baselines and is well-validated for 3B-
class VLAs with a paligemma backbone. Leave it alone.

The one place pi052 needs to diverge from pi05 is the LM-head LR
multiplier:

  * pi05 has no text supervision -> head doesn't get gradients ->
    lm_head_lr_scale is moot, stays at 1.0.
  * pi052 always has text supervision via the recipe (subtask /
    memory / VQA). Under KI, the LM head only sees gradients on
    ~30-45% of the batch (the text-CE mask share). Under aggressive
    cosine decay the head drifts back toward PaliGemma's pretrained
    <loc> first-token bias, despite teacher-forced CE staying near 0.

5x is the documented fix (see PI05Config.lm_head_lr_scale docstring
and PI05Policy.get_optim_params, which is already wired to split the
LM head + tied embed_tokens into their own param group while sharing
the same cosine lambda). Flipping the default here lifts the fix from
opt-in to on-by-default for every pi052 run, with zero downside on
text-free recipes (head still gets no gradients to scale).

Other LR knobs reviewed and intentionally NOT changed:
  - optimizer_lr=2.5e-5: openpi-validated, matches leaderboard.
  - scheduler_warmup_steps=1000: standard for VLA finetuning.
  - scheduler_decay_steps=30000: auto-scales for short runs.
  - optimizer_betas=(0.9, 0.95): GPT/LLM convention, works for
    flow-matching + LM-CE.
  - optimizer_weight_decay=0.01, grad_clip=1.0: standard.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 20:57:43 +02:00
Pepijn 057c794ffe wandb: flip training-example logging defaults to on (every 5000 steps)
The training-example wandb.Table dump (camera images + text fields +
GT/predicted action chunk endpoints) was opt-in. Flip defaults so any
run with --wandb.enable=true gets visual training observability for free.

  log_examples_freq:           0     -> 5000   (push table every 5k steps)
  log_examples_n:              4     -> 4      (unchanged)
  log_examples_predict_actions: False -> True   (extra forward in eval mode)

Runs without --wandb.enable=true are unaffected (the training loop gate
checks wandb_logger is not None first). Set log_examples_freq=0 to opt
out of the dump even with wandb enabled; set log_examples_predict_actions
=false to skip the extra inference forward pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 18:00:04 +02:00
Pepijn b1e83f556c train: periodic wandb log of training examples (images + text + actions)
Adds an opt-in cadence for pushing rich training examples to W&B,
independent of the scalar log_freq. Off by default; turn on with
--wandb.log_examples_freq=5000 (one wandb.Table dump every 5k steps).

WandBConfig (configs/default.py):
  + log_examples_freq: int = 0       # 0 disables
  + log_examples_n: int = 4          # batch elements per dump
  + log_examples_predict_actions: bool = False
                                     # opt-in extra forward pass to
                                     # show predicted vs GT action chunk

WandBLogger.log_training_examples (common/wandb_utils.py):
  Builds one wandb.Table row per sampled batch element with:
    * one wandb.Image column per camera (auto handles CHW/HWC,
      uint8/float32 [0,1])
    * any text fields present in the batch (task / subtask /
      memory / instruction)
    * gt_action_first / gt_action_last (chunk endpoints)
    * pred_action_first / pred_action_last when --wandb.log_examples_
      predict_actions=true (policy.eval() + no_grad; restores train
      mode after)
  Defensive: per-camera failures don't poison the row; predict_action_
  chunk exceptions are logged and the predicted columns are dropped.

Training loop (scripts/lerobot_train.py):
  One new gated block right after the existing scalar log_step clause.
  Reads batch + dataset.meta.camera_keys, hands them to
  log_training_examples. Wrapped in try/except so a bad sample never
  kills the run.

Usage:
  lerobot-train ... \
    --wandb.enable=true --wandb.project=robocasa_composite_seen \
    --wandb.log_examples_freq=5000 \
    --wandb.log_examples_n=4 \
    --wandb.log_examples_predict_actions=true

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 16:57:15 +02:00
Pepijn da3e87ee86 Merge branch 'feat/smolvla-on-steerable' of https://github.com/huggingface/lerobot into feat/smolvla-on-steerable 2026-05-25 16:56:50 +02:00
Pepijn 1e9a6d044d Merge remote-tracking branch 'origin/feat/language-annotation-pipeline' into feat/smolvla-on-steerable
# Conflicts:
#	src/lerobot/datasets/__init__.py
#	src/lerobot/policies/__init__.py
#	src/lerobot/policies/factory.py
#	src/lerobot/processor/render_messages_processor.py
#	uv.lock
2026-05-25 16:56:22 +02:00
pepijn 3fdfcb912a examples(port_datasets): generalize RoboCasa builder + add smoke script
- Add ATOMIC_TASKS, COMPOSITE_UNSEEN_TASKS and four new --task-set keys
  (atomic, composite_unseen, composite_all, composite_atomic) so the same
  builder produces the 50-task target benchmark or the 300-task Human300
  pretraining slice (via --split=pretrain --task-set=all) without
  duplicating logic.
- Stop hardcoding the composite_seen tag on the HF push; tags are now
  derived from --split / --source / --task-set so atomic, composite_all,
  and pretrain runs land with accurate metadata.
- Refresh module docstring to match the broader scope.
- Add scripts/build_robocasa_smoke.sh: 2-atomic-task smoke dataset
  (~1k episodes, ~131k frames) for fast end-to-end training validation
  before kicking off Human300-scale runs.
2026-05-25 14:54:00 +00:00
Pepijn c37b1fc7d0 Merge origin/feat/language-annotation-pipeline (8 fix(annotate) commits + vocabulary phase) 2026-05-25 15:47:25 +02:00
Pepijn 9020635b14 Merge branch 'main' into feat/language-annotation-pipeline
Resolves conflicts from 32 commits on main:

* docs/source/_toctree.yml — keep both new toc entries
  (annotation_pipeline + video_encoding_parameters).
* docs/source/language_and_recipes.mdx — adopt main's section
  ordering (Layer 2 before "Temporal semantics") and float32
  timestamp dtype to match the codebase.
* src/lerobot/configs/__init__.py — keep both export sets
  (recipe + video encoder).
* src/lerobot/datasets/dataset_metadata.py — drop redundant lazy
  imports (top-level imports cover both LANGUAGE_COLUMNS and
  DEFAULT_TOOLS); adopt main's @tools.setter for info.json
  write-back.
* src/lerobot/datasets/feature_utils.py — call the real
  validate_feature_language() instead of returning "".
* src/lerobot/datasets/language.py — float32 timestamps to match
  pa.float32() used in video_utils.py and the rest of the codebase.
* src/lerobot/datasets/language_render.py — adopt main's
  unwrap_scalar() helper (drops two hand-rolled .item()/list
  unwrappers); float32 in docstring.
* src/lerobot/processor/render_messages_processor.py — drop
  PR-local _scalar() helper, use shared unwrap_scalar().
* tests/datasets/test_language.py — adopt main's new float32 dtype
  + validate_feature_language warning tests.
* tests/datasets/test_dataset_metadata.py — adopt main's new
  tools.setter persist/clear tests.
* uv.lock — regenerated cleanly from main's resolver.

90 of 92 touched tests pass. Two pre-existing test failures
(test_module1_plan_memory_subtask_smoke,
test_module2_mid_episode_emits_paired_interjection_and_speech in
tests/annotations/test_modules.py) are unrelated to this merge —
that test file doesn't exist on main, so the failures originate on
the branch and are addressed by the 8 newer fix(annotate) commits
already on origin that will land in a follow-up.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 15:46:32 +02:00
Pepijn 83d0c390da pi052: drop debug scaffolding left over from training/inference bug hunts
Three diagnostic surfaces shipped in PR3 that don't belong in a clean
release:

* ``LEROBOT_DUMP_RECIPE_SAMPLES`` env-var dump (~70 LOC in
  text_processor_pi052.py): pretty-prints the next N rendered samples
  with ``[TGT]...[/TGT]`` markers over supervised spans. One-off
  training-inspection tool — no production user, never wired into a
  CLI flag, only useful while iterating on the recipe. Drop the module
  constants, the ``_is_dump_rank`` / ``_dump_recipe_sample`` helpers,
  the call site, and the now-unused ``import os``.

* ``_log_obs_tensors_once()`` in lerobot_pi052_runtime.py: the
  docstring literally says "Used to bisect train/inference mismatches"
  — a debugging artifact from when the LM head was collapsing on the
  live robot. Logged unconditionally at WARNING level from both the
  dataset-driven and robot-driven providers, with no ``--verbose``
  gate. Drop the function, both call sites, and the ``_logged`` /
  ``_obs_logged`` flag dicts that fed them. (``_resize_logged`` is
  kept — it gates the operationally useful camera-size sanity log.)

* Defensive ``unsqueeze(0)`` block in the dataset observation
  provider: papered over an upstream bug where some preprocessor step
  could produce an unbatched tensor. ``AddBatchDimensionProcessorStep``
  is reliable in the current pipeline — pi052 tests still pass with
  the block removed. If the bug ever resurfaces it should be fixed
  at the source, not silently re-batched here.

Net: -169 LOC. All 30 ``tests/policies/pi052/`` tests pass.

The ``<loc>`` token plumbing (``register_paligemma_loc_tokens``,
``_loc_token``, ``suppress_loc_tokens`` runtime gate) is left as-is —
it's the actual mechanism for VQA spatial answers, not scaffolding,
and the ``suppress_loc_tokens=True`` callers on subtask/memory/
interjection paths and ``=False`` on the VQA path are intentional
asymmetric behaviour, not a bug-routing knob.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 15:07:43 +02:00
Pepijn 1ff10b935c Merge branch 'feat/language-annotation-pipeline' into feat/smolvla-on-steerable
Resolves conflicts from 66 commits on the base branch:

* pyproject.toml — keep base's transformers>=5.4.0,<5.6.0; add the
  sentencepiece-dep entry pi052 (FAST action tokenizer) needs.
* policies/__init__.py — keep pi052 export; drop the
  RewardClassifierConfig export that base removed.
* policies/factory.py — docstring list resolution (keep pi052; drop
  reward_classifier, removed by base).
* annotations/steerable_pipeline/executor.py — adopt base's renamed
  _ensure_annotation_metadata_in_info (it already advertises the say
  tool); drop pi052's older _ensure_tools_in_info call.
* configs/train.py — keep pi052's vqa_target_fraction; adopt base's
  SampleWeightingConfig (legacy RA-BC inline params already covered
  by the migration shim base added).
* scripts/lerobot_train.py — merge pi052's per-policy processor
  rebuild + dataset_repo_id pass-through with base's active_cfg /
  is_reward_model_training tightening, and re-route vqa-weighted
  sampler to active_cfg.drop_n_last_frames.
* datasets/language_render.py — adopt base's _select_one + timestamp
  tolerance (drops pi052's stale _select_latest / per-style sort_key).
* tests — adopt base's parametrized per-camera blend + tolerance
  test; drop pi052 tests that overlap with base's tighter rewrites;
  keep pi052's flow-only / VQA-blend coverage; add a
  test_canonical_recipe_loads check on subtask_mem_vqa_speech.yaml.
* policies/pi052/processor_pi052.py — import RenderMessagesStep
  directly from render_messages_processor (base intentionally
  dropped it from lerobot.processor's re-exports).
* uv.lock — regenerated cleanly from base + pi052's pocket-tts /
  beartype.

All 67 touched tests pass (30 pi052 + 37 recipe / language-render /
pipeline / render-messages).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 14:47:09 +02:00
Pepijn 67bdf4690e examples(port_datasets): rewrite RoboCasa composite_seen builder
Replace the earlier wrapper (which depended on robocasa.scripts.download
+ dataset_registry) with a self-contained pipeline that:

* downloads each task tarball directly from Box via box_links_ds.json
* converts v2.1 -> v3.0 in place using convert_dataset_v21_to_v30
* standardizes camera keys under observation.images.robot0_* and
  flattens observation.state by concatenating base/EE/gripper subkeys
  when the source dataset stores them separately
* builds per-rank unified shards then aggregates into one dataset

Filter: composite_seen task-set restricts discovery to the 16 multi-step
target tasks (DeliverStraw, GetToastedBread, ..., WashLettuce). Use
--task-set=all to keep every discovered task in the split/source slice;
--tasks=... overrides for arbitrary subsets.

Defaults sized for hopper-cpu @ 128 cores: 16 workers x 8 cpus-per-task.

Adapted from a battle-tested port_robocasa.py reference shared by the
user; the only semantic addition is the task-set filter.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 14:27:42 +02:00
Pepijn 8085feab6e pi052(runtime): factor out shared observation-prep boilerplate
Both observation providers in lerobot_pi052_runtime.py ended a sample
dict the same way — strip the runtime-owned language columns and hand
the policy a device-resident ``observation.*``-only subset. Extract
two tiny helpers (``_strip_runtime_owned_language_cols`` and
``_select_observation_to_device``) so the dataset and robot paths
read as a clear linear pipeline. Path-specific concerns (defensive
unsqueeze on the dataset path; camera resize + state-vector sanity
logging on the robot path) stay inline at the call sites.

Behaviour unchanged; all 30 ``tests/policies/pi052/`` tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 14:25:08 +02:00
Pepijn a088c10c80 examples(port_datasets): SLURM+datatrove RoboCasa composite_seen build
Parallel variant of build_robocasa_composite_seen.py modeled after the
existing slurm_port_shards.py / slurm_aggregate_shards.py pattern.

Two-phase datatrove pipeline:
  * Phase 1 DOWNLOAD: tasks=16 (one per RoboCasa composite_seen task),
    each worker downloads its assigned tar via RoboCasa's own
    download_datasets helper. Network-bound, idempotent.
  * Phase 2 AGGREGATE: tasks=1, single worker calls aggregate_datasets
    over the 16 extracted directories. Submitted with depends=phase1 so
    SLURM only releases it once all 16 downloads succeed.

Reuses the COMPOSITE_SEEN_TASKS list and per-task download/resolve
helpers from the single-machine script via aliased imports — single
source of truth for 'what does it mean to download a composite_seen
task'.

Local (--slurm 0) mode runs the two phases sequentially in-process for
debugging on a workstation.

Usage on SLURM:
    uv run python examples/port_datasets/slurm_build_robocasa_composite_seen.py \
        --output-dir=/scratch/${USER}/robocasa_composite_seen \
        --hub-repo-id=${HF_USER}/robocasa_composite_seen \
        --logs-dir=/scratch/${USER}/logs/robocasa \
        --partition=cpu --push-to-hub

Prereq: uv sync --extra annotations  (pulls datatrove)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 14:10:05 +02:00
Pepijn 9c3d5ab7ce scripts: build_robocasa_composite_seen — aggregate 16 target tasks
RoboCasa 1.0 ships its target/human demos in LeRobot format (parquet +
mp4) as lerobot.tar archives distributed via Box. This script wraps
RoboCasa's own download_datasets helper to pull each of the 16
composite_seen tasks, opens each extracted directory as a
LeRobotDataset, and merges them into a single combined dataset via
merge_datasets (a thin wrapper over aggregate_datasets that revalidates
fps/robot_type/features, unifies task indices, concatenates videos and
parquet, and recomputes stats).

The 16-task slice corresponds exactly to the 'Composite-Seen' column of
the published RoboCasa365 leaderboard, so the resulting dataset is the
right substrate for an apples-to-apples pi05 vs pi052 comparison on
multi-step kitchen manipulation.

Usage:
    uv run python -m lerobot.scripts.build_robocasa_composite_seen \
        --output-dir=/data/lerobot/robocasa_composite_seen \
        --hub-repo-id=${HF_USER}/robocasa_composite_seen \
        --push-to-hub

Idempotent: re-running skips already-downloaded tasks. Defensive
fallbacks handle RoboCasa API drift in get_ds_path / download_datasets.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 14:01:28 +02:00
Pepijn e84f97a8c1 smolvla2(runtime): interactive task picker + drop action diagnostic
Task picker:
The dataset bootstrap used to silently overwrite args.task with the
canonical training task. Replace that with an interactive picker
(_select_task_interactively) that shows every unique task in
ds_meta.tasks as a numbered menu (canonical task first as default) plus
a 'type a custom task' option. --task on the CLI still skips the
picker, and non-TTY runs fall back to the bootstrap task so scripted
invocations are unchanged.

Action diagnostic removal:
Drop the [act] log block in LowLevelForward.run (|a|_mean / spread /
normalized + unnormalized first/last + state) that was added while
debugging the 'barely moving' issue. Robot motion is now healthy, the
output is noise in steady-state, and it depended on stashing the
postprocessor on runtime.state — also removed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 12:59:08 +02:00
Pepijn 6d2b8c80ab smolvla2(runtime): wire MemoryUpdateFwd into the inference pipeline
MemoryUpdateFwd was importable but never installed, so subtask_change
events fired by HighLevelSubtaskFwd had no listener and current_memory
stayed at its initial None value — the runtime panel always showed
'memory (not set)' even when the policy was trained with the
memory_update recipe (e.g. subtask_mem_vqa_speech.yaml, weight 0.15).

Insert MemoryUpdateFwd between HighLevelSubtaskFwd and AskVQAFwd so
the event is visible the same tick it is emitted, and refresh the
stale comment that claimed memory was not in scope.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 12:52:44 +02:00
Pepijn 793c7c4ddd feat(runtime): --subtask_chunks_per_gen throttles HL gen vs action chunks
Adds a per-chunk-boundary counter to HighLevelSubtaskFwd: subtask gen
fires only once every N chunk boundaries (default 1 = current
behavior). Lets the operator run e.g. 5 flow-matching action chunks
per LM-head subtask gen so the subtask doesn't churn every 1.7s while
the previous one is still being executed — saves compute and avoids
re-planning the action trajectory mid-grasp.

  --subtask_chunks_per_gen=5    # 5 chunks per subtask refresh

The counter starts at 0 so the very first chunk boundary fires
immediately (no startup delay). Trigger is rearmed when skipping so
a low high_level_hz doesn't lose slots.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 12:34:59 +02:00
Pepijn db927ab40b feat(runtime): action chunk diagnostic — log normalized + unnormalized values
Adds a per-chunk log line in LowLevelForward that surfaces what the
action expert actually emits and what the robot receives after the
postprocessor unnormalizes it, so "barely moving" can be diagnosed
at a glance:

  [act] T=50 |a|_mean=0.234 spread=0.512
  [act] norm  first=[0.12, -0.31, ...]  last=[0.45, -0.22, ...]
  [act] joint first=[3.2, -47.8, ...]  last=[12.4, -41.0, ...]  state=[0.5, -55.3, ...]

|a|_mean ~ 0.3–0.6 with spread ~ 0.3+ and visible delta from first to
last → healthy trajectory. |a|_mean near 0 across the chunk → model
defaulting to median pose. joint values that don't differ much from
state → safety cap or model output near current state.

Postprocessor is stashed on runtime.state["_postprocessor"] at startup
so the diagnostic can replay the same unnormalize the dispatcher uses.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 12:10:52 +02:00
pepijn 471b2b1b1d fix(annotate): bump same-frame subtasks onto distinct frames
If two consecutive VLM-emitted subtask spans have ``start`` timestamps
that round to the same source frame after ``snap_to_frame`` (e.g. on
short episodes the VLM sometimes nominates two ~adjacent action
boundaries within one 30 Hz step), the writer emits two
``style=subtask`` rows at the identical persistent timestamp. The
training-time renderer's default binding
``subtask: active_at(t, style=subtask)`` then raises:

    ValueError: Ambiguous resolver for style='subtask';
                add role=..., tool_name=..., or camera=... to disambiguate.

… and the whole training run dies on the first batch.

Observed concretely on ``pepijn223/super_poulain_vocab2`` (job
22159979): episodes 3 and 30 each had two subtask rows at the same
timestamp (``release yellow cube`` + ``retract arm`` snapping to the
same frame).

Add ``_dedupe_starts_to_distinct_frames`` to walk the cleaned span list
and, whenever a snapped start collides with one already used, push the
later span onto the next free frame timestamp. Both subtasks survive
on distinct timestamps; the renderer can now disambiguate. If the
episode genuinely has no later free frame (extremely unlikely — would
require a same-timestamp collision on the very last frame of the
episode), the later span is dropped with a warning rather than left
to poison the render.

New test ``test_plan_module_bumps_collocated_subtasks_to_distinct_frames``
locks in the contract; full vocabulary suite is 14/14 green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-23 19:31:44 +00:00
pepijn a15e16c072 fix(annotate): replace fuzzy subtask snapping with strict match + one-shot retry
The Jaccard-overlap snap was warping VLM output into wrong canonical
labels — e.g. an off-vocab "consult the wizard" span would silently
become "grasp blue cube" if that scored highest. Even with a higher
floor the operator can't tell which subtasks were paraphrases vs
genuine mislabels in the resulting dataset.

Replace with strict exact-match validation + a single targeted retry:

  1. Generate subtasks as before.
  2. If any returned subtask's normalised form (lowercased, articles
     stripped, whitespace collapsed) isn't in the canonical vocab,
     fire one retry call naming the offending strings and re-sending
     the full canonical list. The retry prompt requires byte-identical
     output from the vocab.
  3. After the retry, validate again. Spans still off-vocab are
     dropped — no fuzzy snapping ever produces a different canonical
     label than the VLM actually emitted.
  4. If every span ends up off-vocab even after the retry, warn loudly
     so the operator extends ``meta/canonical_vocabulary.json`` to
     cover the missing phase. The episode is left with empty subtasks
     rather than silently fabricated ones — visibility > sweep-under-
     the-rug.

Promote ``_NORMALIZE_STRIP_TOKENS`` to a class constant and split the
normalisation helper out so the retry-validation and the final
canonicalisation share one source of truth.

Tests:
  - test_plan_module_accepts_article_only_difference: "grasp the blue
    cube" still maps to canonical "grasp blue cube" (article-tolerant).
  - test_plan_module_retries_when_subtask_off_vocab: paraphrase
    triggers the retry which the VLM corrects in pass 2.
  - test_plan_module_drops_off_vocab_subtask_after_retry: VLM that
    refuses to correct → bad span dropped, in-vocab span kept.
  - test_plan_module_empty_when_all_off_vocab_after_retry: every
    span off-vocab → episode left empty (no warping).
All 13 vocabulary tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-23 09:57:27 +00:00
pepijn 336af85c09 fix(annotate): never leave an episode with zero canonical subtasks
When the canonical vocabulary is enabled and the VLM produces spans
that don't overlap any canonical label, the previous Jaccard-floor
(0.5) dropped them and the episode came out with no subtasks at all
— invisible to the downstream policy. Observed on
``pepijn223/super_poulain_vocab``: some episodes had empty subtask
columns because every VLM-emitted phrase scored below 0.5 against
the discovered vocabulary.

Two-pass canonicalisation:

  - First pass keeps the Jaccard floor (lowered from 0.5 → 0.25, to
    let mild paraphrases through) and drops everything below.
  - If that first pass leaves the episode with **zero** subtasks,
    fall back to a second pass that always snaps each VLM span to
    its nearest canonical label by Jaccard (no floor). The episode
    ends up with subtasks even when the vocabulary missed a phase
    — a slightly-wrong canonical label is still closer to the right
    motion than nothing at all.
  - Log loudly when the fallback fires so the operator can spot
    coverage gaps in ``meta/canonical_vocabulary.json``.
  - Log a per-episode count at INFO when some (but not all) spans
    were dropped so it's visible without spamming the run output.

Promote the Jaccard floor + ignore-tokens to class constants so
they're a single edit point. Add ``force=True`` parameter to
``_canonicalize_subtask`` for the no-floor fallback path.

New test ``test_plan_module_snaps_when_all_off_vocab`` covers the
fallback; existing ``test_plan_module_drops_off_vocab_subtask`` is
adjusted to keep at least one in-vocab span so the floor path can
still fire and is exercised. All 12 vocabulary tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-22 12:44:03 +00:00
pepijn 54221ceea2 feat(annotate): let the VLM decide vocabulary size
Hardcoding ``n_subtask_target=10`` and ``n_memory_target=6`` baked task
complexity into the config — a simple pick-and-place needs ~6, a
multi-step recipe needs ~20. The VLM already sees the clips, so let it
pick the count itself from what's recurring across episodes.

Drop both knobs from ``VocabularyConfig`` and the ``module_0_vocabulary``
prompt template. The prompt now says "decide the count yourself based
on what you see — the smallest set that still covers every recurring
phase" and adds an "each label must recur across the demos" rule so
the VLM filters out one-off motions.

Update the launcher script + docs to remove the old knobs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-22 11:46:31 +00:00
pepijn 369ab17110 fix(annotate): update run_hf_job CLI args for renamed namespaces + phase 0
Three stale things in the launcher script:

  - ``--module_1/2/3.*`` no longer exist; review commit fd18beb renamed
    the CLI namespaces to ``--plan/interjections/vqa``. Forwarded all
    eight existing args to their new names.
  - ``--push_to_hub`` is now a bool; the destination repo lives at
    ``--dest_repo_id``. Split the single positional into both args.
  - ``openai`` was missing from the pip install list, which the prior
    review review (claude bot, 2026-05-08) flagged — the default vlm
    backend is ``openai`` so the job would have ImportError'd. Added.

Also expose the new phase 0 (canonical vocabulary discovery) knobs
explicitly: ``--vocabulary.sample_episodes``, ``--n_subtask_target``,
``--n_memory_target``. Defaults are sane (3 / 10 / 6) but worth
flagging in the example so the operator knows what they're running.

Update the docstring + section comments to match the current phase
layout (vocabulary → plan → interjections → vqa → writer).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-22 11:43:06 +00:00
pepijn 86a7edc590 feat(annotate): phase 0 — derive canonical vocabulary from sample episodes
The pipeline previously emitted near-unique subtask + memory phrasings
per episode (free-form LLM rephrasing). On the downstream low-level
policy that collapses the action expert's conditioning to noise: every
episode pairs a different paraphrase with similar motions, so the
expert learns a flat scene-prior that ignores the subtask string —
then at inference the high-level head invents *yet another* paraphrase
and the expert produces tiny "uncertain hover" chunks.

Add a vocabulary-discovery phase (phase 0) that runs once per dataset:

  - watches the first ``vocabulary.sample_episodes`` (default 3)
    episode videos as one Qwen-VL prompt,
  - asks the VLM to derive ~``n_subtask_target`` canonical imperative
    subtask labels and ~``n_memory_target`` first-person past-tense
    memory milestones that recur across the demos,
  - persists them to ``meta/canonical_vocabulary.json`` (human-
    inspectable, hand-editable), and
  - wires the resulting ``Vocabulary`` into the ``plan`` module so
    every per-episode subtask + memory call is constrained to those
    exact strings (both as prompt-side instructions *and* post-VLM
    validation: paraphrases snap to the closest canonical entry via
    token-set overlap; below a 0.5 Jaccard floor the subtask is
    dropped rather than warped into something semantically wrong).

Operator workflow:

  - first run discovers the vocabulary, writes the JSON, and runs
    the ``plan`` module against it,
  - subsequent runs reuse the on-disk file (``reuse_existing=True``
    default) so hand-edits stick,
  - set ``--vocabulary.enabled=False`` to fall back to free-form
    generation (the original behaviour).

The discovery prompt forbids gerunds / third-person / adverbs and
caps the lists to the requested counts, matching the Hi-Robot /
π0.6-MEM convention of small per-environment vocabularies. The
``plan`` module's subtask + memory prompts grow a conditional
``{vocabulary_block}`` slot rendered only when a vocabulary is
present; without one the templates collapse to their previous
free-form form.

Tests: 11 new unit tests under tests/annotations/test_vocabulary.py
cover the on-disk round-trip, discovery against the fixture dataset,
``reuse_existing`` short-circuit, paraphrase canonicalisation, off-
vocab subtask dropping, and the no-vocabulary pass-through path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-22 11:40:05 +00:00
pepijn 77a16db529 fix(smolvla2): make HighLevelSubtaskFwd actually fire at low hz + quiet startup log
Two runtime fixes that surfaced from on-robot testing.

(1) HighLevelSubtaskFwd was double-gated: HzTrigger fires every period
(e.g. every 5s at --high_level_hz=0.2) AND the step requires the
action queue to be empty. The queue-empty window is brief (~tens of
ms between drain and refill) and almost never coincides with the
low-hz timer, so HL effectively never fired and the subtask shown
in the runtime panel stayed on the dataset's frame-0 annotation.

Add HzTrigger.rearm() and have HighLevelSubtaskFwd call it when
skipping due to queue-non-empty — the trigger stays armed and tries
again on the next tick instead of waiting another full period.
LowLevelForward keeps the original "skip" semantics because chunk_hz
is meant as a true upper bound on chunk-generation rate.

(2) The "robot state at startup" warning in _build_robot_observation_provider
was meant to fire once but wasn't gated by _resize_logged like the
sibling "camera ... live=AxB" warning. Result: it spammed every
observation tick (~1-2s). Gate it on first_call (snapshot of
_resize_logged["done"]) so both logs fire once at session start.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-22 11:04:12 +00:00
pepijn ca1b951e7b feat(pi05): expose lm_head_lr_scale for stronger text-CE gradient
With knowledge_insulation=True the LM head only receives gradients on
text-CE samples (e.g. ~45% of the mix for subtask_mem.yaml). Under
aggressive cosine LR decay this is enough for the head's first-token
distribution to drift back toward PaliGemma's pretrained <loc>
detection prior — teacher-forced argmax stays high while autoregressive
generation collapses to <locDDDD> tokens.

Add `lm_head_lr_scale` (default 1.0, no behavior change) on PI05Config.
When != 1.0, PI05Policy.get_optim_params splits the policy into two
param groups: the PaliGemma lm_head projection plus its tied
embed_tokens at lr * lm_head_lr_scale, and the rest at lr. The cosine
scheduler multiplies both groups by the same lambda each step, so the
ratio is preserved across decay.

Recommended starting point for pi052 + subtask_mem.yaml runs: 5.0,
combined with a higher scheduler_decay_lr floor (e.g. 5e-6 instead of
1e-6) so the head doesn't get starved in the second half of training.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-22 09:56:46 +00:00
pepijn 9d30d91021 fix(pi052,smolvla2): unblock text generation when LM head drifted to <loc>
PaliGemma's pretraining puts heavy first-token mass on its <loc0000>..
<loc1023> ids at any "Assistant:" continuation. Our pi052 fine-tunes
with knowledge_insulation=True and a small text-CE budget (~45% of
samples) drift back toward that prior on long runs at low LR — teacher-
forced argmax stays at 100% (CE only measures next-token given correct
prefix) while autoregressive first-token selection collapses onto <loc>.
On the running poulain11 checkpoint at step 8000 this manifests as a
stream of <locDDDD> tokens for every subtask call — confirmed locally
against the saved checkpoint on a dataset frame.

Add a `suppress_loc_tokens` knob to `PI052Policy.select_message` that
masks ids [256000, 257024) to -inf before sampling, and pass it from
the three text-only inference steps (HighLevelSubtaskFwd,
MemoryUpdateFwd, UserInterjectionFwd). VQA steps keep the default
False so spatial answers can still emit locs. Verified end-to-end:
suppressed → "the robot arm moves the blue block to the green basket".

Also fix `_msgs_for_memory`: it was emitting the older
`User: ${task}\nPlan:..\nMemory:..` / `Assistant: ${subtask}` template,
which no longer matches the `memory_update` recipe layout
(`User: ${task}` / `Assistant: Previous memory: ..` /
`User: Completed subtask: ..`). The new prompt mirrors the training
recipe; `HighLevelSubtaskFwd` stashes the just-completed subtask in
`state['prior_subtask']` so the memory prompt can render
`Completed subtask: ..` for `MemoryUpdateFwd`.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-22 09:50:14 +00:00
pepijn e050d0fe0a fix(recipes): use active_at for memory_update, rebalance subtask_mem
memory_update was bound to `emitted_at(t, style=memory)`, which requires
the frame's exact timestamp to match a memory annotation. Memory rows are
placed at subtask-boundary timestamps and at 30 fps that's ~1% of frames,
so 99% of memory_update draws couldn't render and silently fell through
to _fallback_low_level_render — injecting task-conditioned low-level
training on ~30% of samples (subtask_mem.yaml).

Switch to `active_at`. At inference `MemoryUpdateFwd` is triggered on
`subtask_change` events, but the model only needs to learn the stateless
mapping (prior_memory, completed_subtask) -> current_memory. active_at
supervises this mapping on every frame inside a subtask interval, against
varied observations; the trigger lives outside the model. Net effect:
memory_update renders on ~87% of frames, the fallback leak drops from
~30% to ~4%, and memory CE gets a meaningful (not 0.3%) training share.

subtask_mem.yaml: rebalance to 0.30 / 0.55 / 0.15 so memory CE is
~13% effective and the freed weight goes to low_level_execution.
subtask_mem_vqa_speech.yaml: keep weights (memory_update=0.10 was
already balanced against the other text-CE branches).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-21 14:53:13 +00:00
pepijn 2ca030fa28 fix(pi052): build processors from current config
When fine-tuning from pi05_base, reuse only the pretrained weights so pi052 still generates recipe text labels and FAST action labels.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-21 13:54:29 +00:00
pepijn 36f828221c fix(pi05): preserve pretrained paligemma lm head
Keep the PaliGemma LM head in float32 and initialize it from pretrained weights or token embeddings when loading pi05 checkpoints.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-21 13:25:24 +00:00
Pepijn d41d874581 fix(pi052): debug parity harness truncates prompt instead of masking
The parity check in debug_text_predictions was producing false ✗
DIVERGED reports. Root cause: I built the "inference" batch by
zero-masking the attention past the supervised span, but kept the
full 512-token padded sequence. select_message reads the prompt-end
hidden state via ``vlm_out[:, -1:]`` — the LAST position of the
prefix — which in a padded batch is a padding-token hidden state,
not the last prompt token. PaliGemma's prior on those padded
positions reliably argmaxes to <loc0879>, falsely flagging a
training/inference mismatch.

Fix: truncate both tokens AND mask to length == first_sup before
calling select_message, mirroring what the real runtime does
(``tokenizer(prompt)`` returns un-padded ids). Now the parity check
compares like-with-like.

The actual training argmax in the dump was sensible English
("' move the blue cube into the green bin'" at acc=6/9) — the head
is learning correctly. The "<loc>" salad was purely the harness
reading from the wrong position.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 15:09:36 +02:00
Pepijn efa05f0ada fix(train): unwrap DDP policy in debug_text_predictions hook
At training time the policy is wrapped by Accelerator/DDP into a
.module attribute and custom methods are NOT proxied through the
wrapper, so ``hasattr(policy, "debug_text_predictions")`` was False
and the periodic dump was silently no-op'ing. Walk through .module
indirection to reach the raw PI052Policy that defines the method.

Also surface why the dump didn't fire (no method / empty supervised
positions / generation error) so users can see what's blocking it
instead of staring at silence.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 13:41:20 +02:00
Pepijn e98b6f726b feat(train): debug dump runs inference too, with parity check
Extends the periodic LM-head dump (LEROBOT_DEBUG_PREDS_EVERY) to ALSO
run select_message autoregressively on the same prompt prefix and show:

  prompt                          : '<bos>User: ... Assistant: '
  target  (ground truth)          : ' close the gripper ...'
  training argmax (teacher-fed)   : ' close the gri lift ...'  acc=12/15=80%
  inference (autoregressive)      : ' close the gripper around ...'
  first-token parity              : train=3387 (' close') vs infer=3387 (' close')  ✓ MATCH

The first-token parity check is decisive: training-side argmax at the
prompt-end position and inference's first generated token both compute
``argmax(lm_head(h_last_prompt))`` on identical context, so they MUST
match. Any divergence signals a training↔inference bug (mask, dtype,
KI routing, embedding scale, etc.). Subsequent tokens can diverge
because training uses teacher forcing while inference free-runs.

debug_text_predictions now also returns an ``inference`` list keyed
by sample, each entry carrying ``first_sup_pos`` and ``decoded``.
Limited to 24 new tokens per sample to keep the dump fast.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 12:27:32 +02:00
Pepijn f7747d02a9 feat(train): periodic LM-head prediction dump for live debugging
Adds an opt-in diagnostic that, every N training steps, dumps 5 batch
samples plus the LM head's argmax prediction at every supervised
position alongside the label and a ✓/✗ marker — the cheapest signal
for "is text training actually learning what we expect, or collapsing
to a fixed token". Refills the recipe-sample dump budget on the same
cadence so the raw input shapes are also re-dumped.

Opt in via env var:
  LEROBOT_DEBUG_PREDS_EVERY=1000 lerobot-train ...

PI052 implements ``debug_text_predictions`` (mirrors the text-loss
forward but returns argmax instead of CE); other policies are silently
skipped. The dump runs in eval() mode under no_grad, slicing the
current batch to N samples — no extra data fetch, no train-state
mutation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-21 12:23:05 +02:00
pepijn 86ecd4bc2e add subtask memory training recipe
Add a recipe that blends subtask prediction, low-level execution, and memory update supervision.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-21 09:56:10 +00:00
pepijn 28b86449a2 fix(pi05): cast attention masks to model dtype
Ensure attention masks follow the backbone dtype during bf16 inference to avoid mixed dtype failures.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-21 09:52:46 +00:00
Pepijn 5bb2da4da6 fix(pi052): VQA target format = "label <loc><loc>" not "<loc><loc> label"
The trained model collapsed to spewing 40+ <loc> tokens for *every*
prompt — subtask, memory, anything — because VQA targets were supervised
to *start* with <loc>. With ~25% of all text samples beginning with a
<loc> token, the LM head learned "Assistant: → <loc>" as a strong
attractor; once one loc is emitted, autoregression chains the rest.

Flip the format so every text target — subtask, memory, speech, AND VQA
— starts with a regular word. The model still learns the <loc>
vocabulary for the spatial portion of the answer, but loc can no
longer be the first generation step out of a clean prompt.

Examples:
  point  : "green box <loc0162><loc0759>"
  bbox   : "cube <loc0082>…<loc0409>"
  multi  : "blue <locs> ; yellow <locs>"

The runtime parser (parse_loc_answer) strips loc tokens and uses the
remainder as label, so it's order-tolerant and works under either
format. Old loc-first checkpoints still parse cleanly at inference;
new training will use label-first.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 18:56:48 +02:00
Pepijn f7b989ad97 fix(pi052): read backbone dtype from q_proj, not first parameter
select_message's bf16 cast used next(paligemma.parameters()).dtype,
which lands on a fp32-kept param (norm / embedding) under
to_bfloat16_for_selected_params. Mask stayed fp32 while q/k/v were
bf16 → SDPA still raised "invalid dtype for bias". Read the dtype
from layers[0].self_attn.q_proj.weight instead — q_proj is always
cast with the rest, so its dtype matches what SDPA sees.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 18:46:08 +02:00
Pepijn 3b4376aa33 fix(pi052): cast attention bias to model dtype for bf16 inference
`_prepare_attention_masks_4d` always returns fp32 (the 0.0 / -inf
literals); with bf16 weights, HF PaliGemma's SDPA path raises
"invalid dtype for bias - should match query's dtype" and
select_message returns empty every step. Cast in both attention
sites: `_compute_layer_ki` (training, when both experts run) and
`select_message` (inference, VLM-only branch). Bf16 training +
bf16 inference now run end to end with no dtype mismatch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 18:42:26 +02:00
Pepijn a0233f53f4 feat(annotate): default VLM to Qwen3.6-35B-A3B-FP8
Match the production target used in examples/annotations/run_hf_job.py.
Per Scale Labs' dense-captioning ablations, model capacity dominates
prompt-engineering gains; defaulting to the larger model avoids
shipping a worst-tier configuration out of the box.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 11:46:59 +02:00
Pepijn 34269a5d78 fix(pi052): register PaliGemma <loc> tokens so they tokenize as single ids
THE bug behind the <loc>-salad. PaliGemma's vocab reserves ids
[256000, 257023] for <locDDDD> detection / pointing tokens, but the
stock AutoTokenizer does NOT match them on raw text — it BPE-splits
<loc0162> into SEVEN pieces (<, loc, 0, 1, 6, 2, >). So a VQA target
like "<loc0162><loc0759> green box<eos>" tokenized to 16 pieces, not
5, and training the LM head supervised those generic BPE pieces
instead of one detection-vocab id. The piece logits got pumped up
across ~25% of supervised positions; at inference they dominated
every turn — even subtask prompts produced <loc>-salad followed by
the actual answer.

Register the 1024 <locDDDD> tokens via tokenizer.add_tokens once on
load, in every path the policy uses: PI052TextTokenizerStep (training
encode), _build_text_batch_pi052 (runtime encode), and
select_message's default tokenizer (runtime decode). Verified
empirically with the real PaliGemma tokenizer: VQA target now
tokenizes to 5 ids matching the loc-vocab range (256162, 256759, ...)
with correct offset_mapping.

This unlocks PaliGemma's actual detection prior; <loc>-salad cannot
recur because each <locDDDD> is a single class on the LM head, not a
character sequence the head accidentally learns to extend.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 11:41:41 +02:00
Pepijn 75507491bf fix(pi052): VQA <loc> conversion treats coords as 0-1000 normalized
Confirmed empirically on the published dataset: VQA bbox/keypoint
coordinates are Qwen2.5-VL's 0–1000 normalized grounding output, NOT
pixels. Scanning 8207 samples showed x and y both spanning 0..1000
with ~30% of values exceeding the camera's pixel dimensions (which is
impossible if they were pixels).

_vqa_answer_to_loc was dividing by the observation image's H/W, so
e.g. point [742, 158] on a 640x480 wrist cam clamped x to <loc1023>
(the far-right edge) instead of mapping to <loc0760> (~74% across).
Fix: divide by 1000 — the actual Qwen scale. The conversion is now
camera-resolution-independent, so _camera_image_shapes and the
image_shapes plumbing through __call__ / _encode_messages /
_messages_vqa_to_loc are dropped. Tests updated to the new signature
and the 0–1000 round-trip.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 23:21:28 +02:00
Pepijn 88519cb14c fix(pi052): quantile-normalize actions before FAST tokenizer fit
base.fit() rejected the data with "Vocab size 1024 is too small for
the range of tokens 9339": the FAST tokenizer was fit on raw
motor-unit actions, whose DCT-token range vastly exceeds the 1024
codebook.

Two problems, one fix. (1) Raw actions blow up the token range. (2) At
training time ActionTokenizerProcessorStep runs after the QUANTILES
NormalizerProcessorStep, so it encodes normalized actions — fitting on
raw actions mismatches that space. Replicate QUANTILES normalization
(per-dim [q01,q99] -> [-1,1], clipped) before base.fit() so the fit and
the training-time encode see the same distribution and the token range
fits the codebook.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 23:02:20 +02:00
Pepijn bc0c993b25 fix(pi052): FAST tokenizer fit read actions from column, not ds[i]
fit_fast_tokenizer collected action chunks via ds[i]["action"], which
builds a full training item — delta-timestamp expansion, video decode,
image transforms. A single video-decode failure threw, was swallowed
at debug level, and silently starved the fit of every chunk → "FAST
fit collected zero action chunks", falling back to the universal
tokenizer.

Read the ``action`` column straight from the HF dataset instead: it
carries no video, so it is immune to decode errors and far faster.
Also fail fast with a clear message when the dataset has no ``action``
feature or all episodes are shorter than chunk_size.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 22:51:53 +02:00
Pepijn ddf4bc2063 fix(pi052): knowledge insulation crashed on wrong _gated_residual import
_compute_layer_ki called modeling_gemma._gated_residual, but that
adaRMSNorm gated-residual helper is a lerobot helper in pi_gemma, not
part of HF transformers — so enabling knowledge_insulation crashed with
AttributeError on the first training step. Import _gated_residual from
pi_gemma, matching pi05's own layer code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 22:48:02 +02:00
Pepijn b7317b6c29 test(pi052): round-trip coverage for VQA <loc> conversion
Pins JSON pixel coords -> PaliGemma <loc> -> runtime parse back: the
conversion preserves coordinate order (JSON x-first, <loc> y-first) and
per-axis normalization, losing only <loc>-grid quantization.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 22:24:24 +02:00
Pepijn c026aed8f8 feat(pi052): train VQA spatial answers in PaliGemma <loc> format
Spatial VQA answers (bbox / keypoint) were trained as pixel-coordinate
JSON, which fights PaliGemma's detection prior and leaks <loc>-token
salad at inference. Convert them to PaliGemma's native <locNNNN>
vocabulary instead so the LM head reuses that prior.

Training side (text_processor_pi052.py): a target turn whose content
parses as a bbox/keypoint answer is rewritten to <loc> text, using the
camera frame's native (H, W) from the observation and the preceding
image block. Non-spatial answers, subtask/memory targets and SmolVLA2
keep their JSON form — the dataset stays backbone-agnostic.

Runtime side (smolvla2/inference/vqa.py): parse_vqa_answer detects
<loc> answers (2 locs -> keypoint, 4 -> bbox), returning normalized
[0,1] coords with a normalized flag; draw_vqa_overlay denormalizes
against the chosen camera frame's pixel size.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 20:23:46 +02:00
pepijn e425dfd624 fix(processor): fallback to task message when recipe misses
Keep action-only samples trainable by rendering the task as a low-level user message when no recipe branch matches.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-19 15:32:09 +00:00
Pepijn 15f79b5e5e fix(pi052): supervise an EOS token at the end of each text target
PI052TextTokenizerStep masked text_labels over the assistant turn's
*content only* — the trailing newline was excluded and no EOS token was
ever a supervised label. So the LM head was never given a stop signal:
at inference select_message decoded to max_new_tokens, producing the
runaway subtask paragraphs and the "}"}"}-style VQA tails.

_format_messages now appends the tokenizer's EOS to each supervised
target turn and extends that turn's span to cover it, so the EOS lands
in text_labels. _shifted_ce then trains "<last content token> -> EOS"
and the model learns to terminate; select_message stops on it.

Inference callers (the runtime's _build_text_batch_pi052) pass no
target_indices / eos_token, so no EOS is baked into the prompt — the
model generates it. Verified end-to-end with the PaliGemma tokenizer:
the supervised span is `<content><eos>` and the trailing newline stays
unsupervised.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 17:22:22 +02:00
pepijn 2ea0da2d9f fix(annotate): tag uploaded dataset revision
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-19 12:44:35 +00:00
Pepijn 725ac95b0d feat(runtime): make the interactive runtime drive PI052 too
The runtime's text path was hard-wired to SmolVLA2: _build_text_batch
read policy.config.vlm_model_name (which PI052Config doesn't have) and
built a SmolVLM2 chat-template prompt. PI052/PaliGemma is not
chat-pretrained and trains on a flat `User: ... \nAssistant: ...`
prompt, so the runtime crashed or fed an out-of-distribution prefix.

- _build_text_batch now dispatches on policy.config.type: smolvla2 ->
  chat template (renamed _build_text_batch_chat); pi052 -> flat
  role-prefixed text via PI052TextTokenizerStep's own _format_messages /
  _strip_blocks / _flatten_say_tool_calls, so the inference prefix
  matches PI052 training exactly.
- Add a lerobot-pi052-runtime entry point (alias of the same main; the
  policy type is read from the checkpoint) so the command name isn't
  misleading. argparse prog now defaults to the invoked command name.

PI052's select_message / predict_action_chunk already work with the
runtime; this was the one SmolVLA2-only coupling.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 14:28:55 +02:00
Pepijn 7b64e5498d revert(annotate): move memory + speech prompts to base PR (#3471)
The first-person memory narrative, task-rephrasing and initial-speech
prompt tweaks belong in the annotation pipeline itself. Applied to
feat/language-annotation-pipeline (#3471); reverting them here to the
merge-base so they drop out of this PR's diff. general_vqa.py keeps its
docstring fix since it references a recipe this PR introduces.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 14:17:52 +02:00
Pepijn 134a707c7a feat(annotate): first-person memory narrative + shorter speech prompts
- module_1_memory: rewrite as an explicit first-person, past-tense
  narrative ("I picked up...", "I opened...") matching the MEM
  (Torne 2026) running-memory style, instead of "one or two short
  sentences" with no person/tense guidance.
- module_1_task_rephrasings: bias rephrasings toward short imperative.
- module_2_initial_speech: prefer very short robot acknowledgements.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 14:17:30 +02:00
Pepijn 182f10184f revert(annotate): move pipeline changes to base PR (#3471)
The deterministic-plan rewrite, single-frame VQA (K 3->1), dataset
version tagging, telegraphic-subtask prompt and shorter interjection
prompt belong in the annotation pipeline itself, not in the SmolVLA
training PR. They have been applied to feat/language-annotation-
pipeline (#3471). Reverting these six files here to the merge-base so
they drop out of this PR's diff; #3491 will inherit the canonical
versions when it next rebases on its base.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 14:07:23 +02:00
Pepijn ce47075d6b feat(annotate): deterministic plan, single-frame VQA, dataset tagging
Port the steerable-pipeline refinements developed on feat/smolvla-on-
steerable back into the annotation pipeline itself:

- module_1_subtasks: imperative verb-first telegraphic labels with a
  consistent-object-noun rule and good/bad examples (no hard word cap).
- _generate_plan: drop the VLM round-trip; the plan is now a
  deterministic numbered list of still-todo subtasks, re-emitted at
  every subtask boundary so it shrinks as work progresses. Removes
  module_1_plan.txt.
- VqaConfig.K 3 -> 1: a VQA pair anchors exactly its emission frame, no
  stale-label temporal smear.
- lerobot-annotate: tag the pushed dataset with its codebase_version so
  LeRobotDataset can resolve a revision and load it.
- module_2_interjection: shorter, more natural mid-task cues.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 14:06:15 +02:00
Pepijn 26013da699 feat(annotations): enforce imperative verb-first subtask phrasing
Rewrite module_1_subtasks prompt to produce short imperative commands
("pick up the orange") instead of third-person narration ("the robot
arm moves to the orange"). Drops the verbose "how, not what" rule and
adds a good/bad few-shot table.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 13:53:20 +02:00
pepijn bb31988915 fix(pi052): pass 4d masks to prefix-only forwards
Convert PI052 prefix-only attention masks before calling PaliGemma so text-only batches and generation use the same mask shape as fused training.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-18 21:07:13 +00:00
pepijn 2629175d2d fix(pi05): use fused AdamW by default
Route full PI05/PI052 fine-tuning through PyTorch's fused AdamW path to avoid the single-tensor Adam denominator allocation near GPU memory limits.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-18 19:23:17 +00:00
pepijn 2b4c5f49e3 fix(pi05): disable foreach AdamW by default
Avoid the multi-tensor AdamW temporary that can OOM full PI05/PI052 fine-tuning near GPU memory limits.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-18 18:58:17 +00:00
pepijn 22c9c4905e fix(pi052): avoid dense CE over padded tokens
Select only supervised text and FAST action-code positions before cross-entropy to avoid full-vocabulary loss tensors over padded sequences.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-18 18:40:34 +00:00
pepijn 7960cc14ec fix(pi052): call policy preprocessing helpers
Use PI05Policy helpers for action padding and image preprocessing in PI052 fused losses instead of looking them up on the inner PI05Pytorch module.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-18 17:52:47 +00:00
pepijn 1750a87104 fix(pi052): handle batched rendered messages
Tokenize batched recipe outputs in PI052 so training batches with nested message lists do not crash before model forward.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-18 17:41:58 +00:00
pepijn 0e2dc1b76f fix(pi052): supervise only FAST action-code tokens
Mask the FAST auxiliary loss to discrete action-code tokens so wrapper formatting tokens do not affect action co-training.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-18 17:38:34 +00:00
Pepijn 474c5478d9 tune(annotations): VQA emission anchors a single frame (K 3 -> 1)
Module 3 anchored each VQA emission tick to K=3 consecutive frames
(~0.1s at 30fps). The VLM grounds the answer — bbox/keypoint
coordinates especially — against the first frame's image, so copying it
onto frames 2-3 smears a stale label over a moving scene.

Default K=1: a VQA pair lands on exactly its emission frame, no
temporal smear. VQA frames get sparser; the WeightedEpisodeAwareSampler
(vqa_target_fraction) is the knob to compensate.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 17:24:36 +02:00
Pepijn f72b28738a fix(annotate): default keyframe decode to ffmpeg CLI (thread-safe)
The decoder chain tried torchcodec first, then ffmpeg. torchcodec is
not thread-safe: under the executor's 16-wide concurrent decode in the
interjections phase it SIGSEGVs (exit 139) before the ffmpeg fallback
is ever reached — uncatchable, so it kills the whole job.

Default the auto chain to ffmpeg only. Per-frame ffmpeg decode runs in
an isolated child process: crash-safe and concurrency-safe (the plan
phase already proved 16 parallel ffmpeg subprocesses are fine).
torchcodec / pyav remain available via an explicit video_backend.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 16:40:29 +02:00
Pepijn 1bd53cc7da fix(annotate): decode keyframes via ffmpeg CLI fallback
PyAV segfaulted (exit 139) decoding the AV1 streams modern LeRobot
datasets use — a SIGSEGV that the per-episode try/except cannot catch,
killing the whole job when the interjections phase started.

Replace the PyAV fallback with _decode_frames_ffmpeg, which shells out
to the ffmpeg CLI: a full ffmpeg build decodes AV1, and a child-process
crash is a catchable non-zero exit rather than a segfault. Decoder chain
is now torchcodec -> ffmpeg. _decode_frames_av stays available behind
video_backend="pyav" for callers that want it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 16:08:31 +02:00
Pepijn 0f5f0e4091 refactor(recipes): rename recipes, drop pi05_hirobot
- hirobot.yaml            -> subtasks_vqa.yaml
- hirobot_memory.yaml     -> subtask_mem_vqa_speech.yaml
- pi05_hirobot.yaml       -> deleted (stale: uses plan, top-camera names;
  superseded by the two recipes above)
- smolvla2_hirobot.yaml   -> deleted (was untracked stale junk)

Updated the smolvla2 / pi052 `recipe_path` config defaults, all
docstring / comment references, the annotation-pipeline + recipe docs,
and the three tests that loaded pi05_hirobot.yaml (repointed to the
renamed recipes; the low-level-branch and pipeline-render assertions
now accept a flow-only `low_level` stream as valid supervision, since
the new recipes' low_level_execution has no text-CE target).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 16:02:15 +02:00
Pepijn 7128bb1769 fix(annotate): decode keyframes via PyAV directly
The pyav fallback routed through lerobot's decode_video_frames(backend=
"pyav"), which uses torchvision.io.VideoReader — removed in torchvision
0.23+. On modern torch stacks (e.g. vllm-openai with torchvision 0.26)
both torchcodec and that path fail, leaving interjection/vqa prompts
without visual context.

Add _decode_frames_av: a self-contained PyAV decoder that picks the
nearest frame per timestamp. It is the always-available tail of the
decoder chain (torchcodec -> pyav) and the target of --video_backend=pyav.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 15:45:04 +02:00
Pepijn 426d48dbbf fix(pi052): port the smolvla2 text-head fixes to pi052
pi052 had the same text-CE collapse bug smolvla2 had — PaliGemma's
embed_prefix flags the language block att=0, so make_att_2d_masks makes
it fully bidirectional and the text cross-entropy degenerates into a
copy task. Ported the three model-specific fixes:

- _mark_target_span_causal: set att=1 on supervised target language
  positions so the text-CE is genuine causal next-token prediction.
  Applied in both _compute_all_losses_fused and _compute_text_and_fast_loss.
- flow_loss_weight 10.0 -> 5.0: the paper's a=10 swamps the LM head once
  the flow-only low_level recipe fires often (matches SmolVLA2Config).
- _flatten_say_tool_calls in the text tokenizer: serialize `say` tool
  calls into a <say>...</say> marker so the spoken reply is tokenized
  and supervised (PaliGemma's flat prompt has no structured calls, so
  they were dropped entirely).

select_message needed no change: pi052's prefix is [images, language]
with no trailing state token, so it already decodes from the last
language token.

Regression tests mirror the smolvla2 attention-masking + tool-call suite.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 15:42:19 +02:00
Pepijn fbcb9225f5 feat: oversample sparse VQA annotations (recipe consumption + weighted sampler)
VQA annotations are sparse, so VQA was badly underrepresented in training:
its effective share was weight x density, and blend draws that picked an
ask_vqa* sub-recipe for a non-VQA frame were wasted entirely.

Two pieces:

1. Recipe-side consumption (language_render.py): render_sample now routes
   any frame that carries a VQA annotation to a matching ask_vqa* sub-recipe,
   regardless of the weighted blend draw. No VQA annotation is wasted and no
   draw lands on a non-renderable VQA recipe — VQA's recipe-side share now
   equals the VQA-annotation density.

2. Dataset-side oversampling (WeightedEpisodeAwareSampler + vqa_target_fraction):
   a new weighted, episode-aware sampler draws frames with replacement by
   per-frame weight. When TrainPipelineConfig.vqa_target_fraction is set, the
   train script scans language_events, weights VQA frames so they make up
   ~that fraction of the training stream, and uses the weighted sampler. This
   is what actually lets VQA exceed its natural density. Default None keeps
   uniform episode-aware sampling unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 15:30:00 +02:00
Pepijn 31e0c15e55 fix(annotate): pyav fallback when torchcodec keyframe decode fails
VideoFrameProvider decoded keyframes via torchcodec only. Some containers
(e.g. vllm-openai) ship a torchcodec that cannot push packets to the
decoder ("Operation not permitted"), silently degrading interjection/vqa
prompts to no visual context.

_decode now retries with pyav when the default backend raises, and a new
`video_backend` config field lets callers pin the backend explicitly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 15:23:53 +02:00
Pepijn c5676ef1b3 feat(annotate): add dest_repo_id for separate push target
Adds an optional `dest_repo_id` to AnnotationPipelineConfig. When set,
`push_to_hub` uploads the annotated dataset there instead of overwriting
the source `repo_id`, restoring separate source/destination repos.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 15:05:23 +02:00
Pepijn b319ccf688 fix(smolvla2): only prompt for a camera when a VQA overlay is drawn
The VLM already sees every camera, so the operator never needs to name
one to ask a question. Move the camera prompt to after generation and
only fire it when the answer actually carries a bounding box / point
(whose pixel coordinates are camera-specific and need a target frame).
Non-spatial answers (count / attribute / spatial / plain text) now skip
the prompt entirely.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 14:50:19 +02:00
Pepijn 3174e14bc0 fix(smolvla2): feed all cameras to VQA generation, not just the chosen one
handle_vqa_query filtered the observation down to the single chosen
camera before calling the VLM. But training feeds every camera: the
ask_vqa_* recipes' image blocks are stripped before tokenization and
the frames reach the model via OBS_IMAGES_*, where embed_prefix
consumes all config.image_features regardless of the per-camera recipe
tag. Filtering to one camera changed the image-token count in the
prefix (the dropped camera zero-padded with mask=0) — a prefix shape
the model never saw at training.

Now the full observation is passed to select_message; the chosen
camera is used only to pick which frame the bbox/point overlay is
drawn on.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 14:46:38 +02:00
Pepijn dc530e10fe feat(smolvla2): VQA example prompts in the panel; drop quotes from hints
Command arguments never needed quotes (`_strip_quotes` only strips a
matching pair if present) — `/question point to the yellow cube` works.
The hints wrongly implied `""` were required; all hints/help now show
`/action <task>` / `/question <text>`.

Also adds a reference line to the state panel showing the two
overlay-producing VQA prompt shapes:
  /question point to the yellow cube   -> point overlay
  /question detect the blue cube       -> bounding-box overlay
plus the same examples in /help.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 14:42:32 +02:00
Pepijn e7c5613a39 refactor(smolvla2): command-driven runtime — no startup prompts
Replace the startup mode prompt + task picker with a single
command-driven prompt. The runtime now comes up immediately at the
command line in `paused` mode (robot idle) and the operator drives it:

  /action "task"     run the robot on a task (bare = resume, number = timed burst)
  /pause             stop the action loop — robot holds position
  /question "..."    pause and answer one VQA question (camera prompt + overlay)
  /help / stop

- Removed _select_mode_interactively / _select_task_interactively /
  _dataset_task_strings (the interactive pickers).
- mode value renamed "question" -> "paused"; --mode choices are now
  action|paused (default paused).
- /question takes the question inline and runs it via _handle_slash_command
  (pauses first, so the policy isn't used concurrently).
- The ENTER-to-start gate only fires when starting in action mode.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 14:37:51 +02:00
Pepijn 516ffc7687 feat(smolvla2): --mode flag, skip task picker with --task, timed /action
Lets the operator skip the interactive startup entirely and go straight
to the command line:

- New --mode {action,question} arg; when given, the startup mode prompt
  is skipped.
- When --task is passed explicitly on the CLI, the startup task picker
  is skipped (the dataset-bootstrap task still shows the picker so you
  can override it).

Also adds a timed action burst: /action <seconds> runs the robot for N
seconds, then the autonomous loop auto-reverts to question mode and
clears the action queue. Plain /action stays unlimited.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 14:26:12 +02:00
Pepijn 7a68bf13d9 feat(recipes): add hirobot_memory — hirobot + memory + spoken tool-call replies
New recipe alongside hirobot.yaml (kept as the lean baseline). Superset
that adds two text-supervised sub-recipes:

- memory_update: compress progress into a memory note.
- user_interjection_response: reply to a user interjection with a `say`
  tool call only (no plan/subtask text). The SmolVLA2 chat tokenizer
  flattens the call to a `<say>...</say>` marker the runtime parses back.

Plan is intentionally omitted; memory is the only persistent high-level
state. Weights: low_level 0.40, subtask 0.25, memory 0.10, interjection
0.10, vqa 0.075 x2.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 14:21:41 +02:00
Pepijn 15229468d0 feat(smolvla2): startup mode prompt; rename /vlm mode to /question
Add a mode prompt at startup, shown before the task picker, so the
operator chooses action (run the robot) vs question (VQA only) up front
instead of having to discover /vlm mid-run.

Also rename the VQA mode from "vlm" to the clearer "question":
- state["mode"] value is now "action" | "question"
- the command is /question (/vlm and /vqa kept as aliases)
- panels, hints and help text updated to match

handle_vqa_query now reports via both push_log and direct stdout, so
VQA answers / overlay paths are visible in autonomous question mode
where the panel redraw is suspended.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 14:17:03 +02:00
Pepijn a9cea3e8dd fix(smolvla2): make the autonomous REPL usable for slash commands / VQA
The autonomous panel redraw cleared the screen every 0.5s, so the "> "
prompt and the one-shot command hint vanished — the operator could not
see what to type or what they were typing, making /vlm unreachable.

- Suspend the timer redraw entirely while in /vlm mode (the action loop
  is paused, nothing changes in the background) so the VQA question and
  camera prompt stay on a stable screen.
- Re-print the "> " prompt after each redraw so it is always visible.
- Show an always-on command hint in the panel (/vlm, /help, /action)
  instead of relying on the startup line that scrolls away.
- Redraw immediately after a slash command so the mode flip is visible.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 14:10:13 +02:00
Pepijn 89d4846590 fix(smolvla2): always show the startup task picker on a TTY
The picker was skipped whenever a task was already resolved — which is
always the case with --dataset.repo_id, since the dataset's canonical
task is auto-filled. The operator never got to choose. Now the picker
always runs on an interactive terminal: the resolved task is shown as
"(current)" and selected by an empty Enter, so the dataset-canonical
default still works while letting the operator pick another task or
type a custom one.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 14:04:53 +02:00
Pepijn Kooijmans 9dfc9084e1 review: decode keyframes via video_utils.decode_video_frames
Addresses three of CarolinePascal's frames.py comments (the fourth, the
subprocess re-encode, waits on #3611):

- replace the bespoke _decode_pyav_direct PyAV decoder with
  lerobot.datasets.video_utils.decode_video_frames (torchcodec backend,
  PyAV fallback) — torchvision's VideoReader removal no longer applies
- frames flow through the provider as torch.Tensor (C, H, W uint8); PIL
  is materialised only at the VLM-message boundary in to_image_blocks /
  to_video_block, where the chat backends need it
- _decode now returns exactly one frame per timestamp (or [] on failure),
  so frames_at pairs them with strict=True

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 14:00:38 +02:00
Pepijn Kooijmans fd18beb3a1 review: address CarolinePascal feedback
- name the three modules everywhere (plan / interjections / vqa) instead
  of module_1/2/3 — config classes, config fields, executor params,
  staging keys and phase names now carry the module name
- rename examples/annotation -> examples/annotations; add the Apache
  header to run_hf_job.py
- drop the unused GeneralVqaModule._generate_one
- remove "PR 1" references from comments/docstrings
- frames.py: rely on the always-defined LeRobotDatasetMetadata.camera_keys
- executor.py: read/write meta/info.json via load_info / write_info
- reader.py: load meta/tasks.parquet via io_utils.load_tasks
- make --push_to_hub a bool; push the annotated dataset back to --repo_id
- move the on-disk test dataset builder into tests/fixtures
  (build_annotation_dataset); run_e2e_smoke reuses it
- clarify in the docs that the vqa module grounds each pair on a single
  frame (K = per-tick anchor count)
- hoist stdlib dynamic imports to module scope

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 12:03:25 +02:00
Pepijn 26cb38a7d0 feat(smolvla2): startup task picker, /vlm mode toggle, interactive VQA overlay
Three additions to the SmolVLA2 interactive runtime:

1. Startup task picker — when no --task is given, the runtime lists the
   dataset's task strings as a numbered menu (plus a custom-task option)
   instead of silently waiting for the first stdin line.

2. Mode toggle — /action and /vlm slash commands flip a persistent run
   mode. /vlm pauses the whole action loop (HighLevelSubtaskFwd,
   LowLevelForward and DispatchAction gate on state["mode"]) and clears
   the action queue so the robot holds position; /action resumes it.
   The mode is shown in the state panel.

3. Interactive VQA — in /vlm mode a typed line is a VQA question. The
   new inference/vqa.py module asks which camera to ground on, runs the
   VLM on that single camera, and when the answer is a bbox/keypoint it
   draws the overlay, saves a PNG to ./vqa_overlays/ and auto-opens it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 11:20:57 +02:00
Pepijn bfb8cfb432 fix(smolvla2): flatten say tool_calls into <say> marker before tokenizing
The chat tokenizer passed assistant `tool_calls` straight to
`apply_chat_template`, which renders them as a structured JSON
`<tool_call>` block — so the LM head was trained to emit JSON. But the
inference parser `_split_plan_and_say` looks for a `<say>...</say>`
marker, which the model never saw in training, so the `say` tool never
fired at inference.

`_flatten_say_tool_calls` is the missing training-time serializer (the
one `_split_plan_and_say`'s docstring already assumed existed): it
rewrites a `say` tool call into a `<say>...</say>` marker inside the
content text before the chat template runs, so the template only
tokenizes plain text and the supervised target span trains the model to
emit exactly the marker the runtime parses back (Pi 0.5-style flat
tool-call serialization).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 10:47:31 +02:00
Pepijn 5e3b9ba82c tune(smolvla2): override optimizer_lr to 2.5e-5 for pretrained-LM fine-tuning
SmolVLA's 1e-4 is safe only because it freezes the language head. SmolVLA2
unfreezes lm_head + the last text layer and fine-tunes the pretrained
SmolVLM2 language weights; 1e-4 is too aggressive there and destabilises
generation into degenerate repetition. Match pi05's 2.5e-5 peak LR.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-18 10:41:13 +02:00
Pepijn 083d3cd419 tune(smolvla2): soften flow:text loss split from 10:1 to 5:1
The Pi 0.5 α=10 split assumed text is a rare auxiliary task. With the
flow-only `low_level` recipe (~40% of the blend) now rendering, the flow
term fires often and at 10x weight dominates the shared VLM backbone,
starving the text head into degenerate repetition decoding. A 5:1 split
keeps actions primary while leaving the language head enough gradient.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-17 16:00:08 +02:00
Pepijn bf996c7938 fix(datasets): render flow-only low_level recipes instead of dropping them
A recipe whose only supervision is the action-expert flow loss (e.g.
`low_level_execution`: `user(${subtask})` with `stream: low_level` and no
`target` turn) was rejected at render time by `_render_message_recipe` and
`_validate_rendered`, both of which required at least one target turn.

The result: every blend draw of the flow-only recipe rendered to `None`,
`predict_actions` was never set, `run_flow` never fired, and the action
expert received no flow loss — leaving it at random init. Both gates now
also accept a `low_level`-stream turn as valid supervision.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-17 13:20:39 +02:00
Pepijn 0d88eaf8eb test(smolvla2): attention masking of the language target span
Regression coverage for the text-CE collapse bug fixed in 3cd348ff.
Pure-function tests over ``_mark_target_span_causal`` /
``_locate_lang_range`` / ``make_att_2d_masks`` — no model load, fast.

Pins:
* the target span flips to att=1, prompt/images stay att=0;
* target tokens attend causally among themselves (no peeking at
  future targets) — genuine next-token prediction;
* targets still attend bidirectionally to images + the user prompt;
* the action-expert (state) token still attends to every target;
* a no-target subtask (low_level_execution user turn, labels all
  -100) leaves the mask bidirectional;
* an explicit test documenting the bug: the raw embed_prefix mask
  lets the first target token see the last — the copy-task collapse.

Skips cleanly when transformers isn't installed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-16 18:28:44 +02:00
Pepijn 3cd348ffe2 fix(smolvla2): causal mask on the text-CE target span (THE collapse bug)
Root cause of every collapsed inference run. ``embed_prefix`` flags
all language tokens ``att=0``; ``make_att_2d_masks`` turns that into
a single fully BIDIRECTIONAL block. So during the text-loss forward,
a supervised subtask token's hidden state attends to the very tokens
it is trained to predict. The cross-entropy degenerates into a copy
task — ``text_loss → ~3e-5`` not because the model learned to
predict subtasks but because it can see the answer.

At inference ``select_message`` decodes autoregressively (causally):
each token must be predicted WITHOUT seeing it — a task the model
was never actually trained on. Hence the universal collapse: a
coherent first token or two ("grasp the yellow cube"), then a loop
("cover cover cover", "icatorsicators", "the the the").

Fix: ``_mark_target_span_causal`` sets ``att=1`` on the language
positions that are supervised targets (``text_labels != -100``).
With make_att_2d_masks's cumulative-block rule each target token
then attends to images + the user prompt bidirectionally and to
EARLIER target tokens only — genuine causal next-token prediction,
matching select_message. Applied in both ``_compute_text_loss`` and
``_compute_fused_loss``. Per-sample correct: high_level_subtask
targets become causal; low_level_execution subtasks (a user turn,
labels all -100) stay bidirectional so the action expert reads them
as bidirectional context. The action expert is otherwise unaffected
— the suffix has a strictly higher cumsum and still attends to the
whole prefix.

Requires retraining: this changes the training objective. Existing
checkpoints were all trained on the degenerate copy task and cannot
generate text. Expect ``text_loss`` to settle MUCH higher than 3e-5
after this — that is correct; it is now a real prediction task.

NOTE: pi052's text path (PaliGemma prefix-LM) has the same
bidirectional-block structure and needs the analogous fix.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-16 18:24:44 +02:00
Pepijn db03fc6dc4 fix(smolvla2): select_message must decode from the language position
``embed_prefix`` lays the prefix out as ``[images, lang, state]`` with
the state token LAST. Training supervises the text head on the
*language* positions (``_compute_text_loss`` / ``_compute_fused_loss``
slice ``prefix_out[lang_start:lang_end]`` and run lm_head there).

But ``select_message`` started AR generation from the full prefix and
read ``prefix_out[:, -1:]`` — the **state token** — to decode the
first subtask token. The state token's hidden state exists for the
action expert to read; the lm_head was never trained to produce
subtask text from it. So inference decoded the high-level head from a
position entirely outside the training distribution: the text head
collapses (``the arm the arm``, ``grasp the surface population``,
``_333 absburg…``) no matter how cleanly ``text_loss`` converged.

Fix: truncate the state token off the prefix before the AR loop, so
``prefix_out[:, -1:]`` is the last language token (right after the
``Assistant:`` generation prompt) — exactly where training supervised.

Inference-only change — no retraining needed; existing checkpoints
benefit immediately. The action path (``predict_action_chunk``) is
untouched: state belongs in the action expert's prefix.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-16 15:05:16 +02:00
Pepijn 56068d37ea fix(smolvla2): default load_vlm_weights=True — don't train from scratch
SmolVLAConfig defaults ``load_vlm_weights=False``. With that and no
``--policy.path``, ``SmolVLMWithExpert.__init__`` builds the VLM via
``SmolVLMForConditionalGeneration(config=...)`` — i.e. a fully
**random-initialised** 500M backbone, including a random ``lm_head``.

For plain SmolVLA that's a deliberate "pre-train the expert" mode.
For SmolVLA2 it's a footgun: the high-level text head *is* the
SmolVLM2 ``lm_head``. Training subtask prediction from a random
language model can only memorise — which is exactly the repetition
collapse seen on the real robot ("the arm the arm the arm …").

SmolVLA2 now defaults ``load_vlm_weights=True`` so every run
fine-tunes the pretrained ``HuggingFaceTB/SmolVLM2-500M-Video-Instruct``
backbone (vision tower + language model + lm_head). The action
expert still trains from scratch on the robot data (standard SmolVLA
fine-tuning); start it from pretrained too by fine-tuning a full
``lerobot/smolvla_base`` checkpoint via ``--policy.path``.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-15 16:44:00 +02:00
Pepijn e727688052 annotate: telegraphic subtasks — ≤4 words, verb+object, consistent nouns
Tighten the subtask prompt further per real-data feedback. The old
≤5-word cap still produced things like "release the yellow block
into the green bin" (8 words, articles, destination, and "block"
where the task said "cube").

New rules:
* Hard cap ≤ 4 words, ideally 2-3. Form: VERB + (color) + OBJECT.
* No articles, no destinations, no adverbs, no "robot/arm/gripper".
* Must reuse the exact object nouns from the task — no block/cube,
  bin/box/container drift across the episode.
* Concrete good/bad examples anchored on the cube task.

Shorter, templated, consistent targets are far more robust for the
autoregressive LM head — fewer tokens to drift on, fewer dominant
n-grams to repetition-collapse into.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-15 14:14:42 +02:00
Pepijn f1a0a663cc fix(inference): gibberish detector catches long repetition collapse
The ``_looks_like_gibberish`` low-unique-token check was gated on
``len(stripped) < 80``, so an LM head that loops an n-gram for the
whole 256-token budget — "the arm the arm … the the the the" —
sailed straight through (``gibberish:0`` in the panel) and the
garbage subtask got accepted and fed to the action expert.

Added a length-independent check: ``>= 8 tokens`` but unique-token
count ``<= max(3, tokens // 10)`` ⇒ repetition collapse. Now the
runtime rejects the looped output and keeps the previous (real)
subtask instead of propagating nonsense.

This is a guard, not a cure — the underlying issue is the LM head
on the current checkpoint being undertrained / collapsed; re-
annotate with the short prompts and train longer.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-15 13:52:26 +02:00
Pepijn 6e64c20cf1 runtime: stop seeding plan/memory from the dataset (unused)
The current recipe trains neither plan nor memory, and no inference
step consumes them — ``_msgs_for_subtask`` renders the bare task and
``LowLevelForward`` conditions on the subtask. Bootstrapping
``current_plan`` / ``current_memory`` from the dataset's
``language_persistent`` annotations therefore only placed a stale,
do-nothing plan in the status panel.

Keep seeding ``current_subtask`` — it's a useful first-frame
fallback for ``LowLevelForward`` before ``HighLevelSubtaskFwd``
produces its first subtask.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-15 13:47:33 +02:00
Pepijn b29cccb37e runtime: restore the subtask hierarchy — generated subtask drives actions
Reverts the previous "condition actions on the task" shortcut.
The action expert is conditioned on the SUBTASK again:

* ``low_level_execution`` recipe back to ``user(${subtask})``.
* ``LowLevelForward`` conditions on ``current_subtask`` (falls back
  to the task only on the first frame, before the high-level loop
  has produced a subtask).
* ``HighLevelSubtaskFwd`` re-added to the runtime pipeline so the
  subtask is actually generated each high-level tick and written to
  ``current_subtask`` before ``LowLevelForward`` consumes it.
* ``_msgs_for_subtask`` now renders just ``${task}`` (no
  ``Plan: ``/``Memory: `` lines) to match the current
  ``high_level_subtask`` recipe, whose user turn is the bare task.

So the loop is: task → HighLevelSubtaskFwd (LM head) → subtask →
LowLevelForward → action chunk conditioned on that subtask.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-15 13:43:04 +02:00
Pepijn f161e27e96 recipe+runtime: condition the action expert on the task, not the subtask
Real-robot runs shook and failed the task despite a low flow loss.
Root cause: train/inference conditioning mismatch — not a flow-loss
bug (``_compute_fused_loss``'s flow path is byte-identical to
``SmolVLAModel.forward``).

At training, ``low_level_execution`` conditioned the action expert
on ``${subtask}``, and every frame's subtask was the correct one
for that frame. At inference the runtime has no high-level subtask
generator (VQA-only pipeline), so ``current_subtask`` was frozen —
the action expert got "move towards the blue cube" for the entire
episode. Once the arm reached the cube, that (image, subtask) pair
never occurred in training → OOD conditioning → incoherent flow
output → shaking.

Fix: ``low_level_execution`` now renders ``user(${task})``. The
task is stable for the whole episode and always available, so the
action expert's conditioning is identical at train and inference
with no high-level loop required. ``LowLevelForward`` updated to
build the same ``[user(task)]`` prompt.

``high_level_subtask`` still trains the text head to predict
subtasks (kept for when a reliable subtask loop is reintroduced) —
it's just no longer on the action expert's critical path.

Requires re-training for the recipe change to take effect.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-15 13:40:15 +02:00
Pepijn d5f293a1c9 recipe+runtime: VQA + subtask only — drop plan & memory
Scope reduction while the core subtask + action loop is validated:

Recipe (hirobot.yaml)
* Removed ``plan_generation`` sub-recipe entirely.
* Removed the memory tail from ``high_level_subtask`` (the
  ``new_memory`` binding + the second assistant turn).
* ``high_level_subtask`` user turn is now just ``${task}`` — no
  ``Plan: …\nMemory: …`` context.
* Weights rebalanced over the four remaining sub-recipes:
  high_level_subtask 0.40, low_level_execution 0.40,
  ask_vqa_top/wrist 0.10 each.

Runtime (inference/runtime.py)
* Pipeline trimmed to VQA + the action loop:
  AskVQAFwd → LowLevelForward → DispatchAction → DispatchToolCalls.
* Dropped HighLevelSubtaskFwd / MemoryUpdateFwd / UserInterjectionFwd
  from the default pipeline. They remain importable from
  ``inference.steps`` for when plan/memory/subtask generation is
  brought back. The action expert conditions on the task string
  directly via LowLevelForward's ``current_subtask or task``
  fallback.

This commit lands on top of a rollback of the previous two commits
(repetition_penalty / no_repeat_ngram_size knobs, and the
deterministic plan-walker) — both were bandaids for the LM-head
repetition collapse that the reduced-scope recipe sidesteps.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-15 08:02:06 +02:00
Pepijn 95033733fc deps: add sentencepiece to the pi extra (FAST action tokenizer)
PI052 and PI0_FAST both load ``physical-intelligence/fast`` as
their action tokenizer. That tokenizer's HF backend requires
``sentencepiece`` to instantiate (or ``tiktoken``); without it
``AutoProcessor.from_pretrained`` raises:

  ValueError: Couldn't instantiate the backend tokenizer from one of:
  (1) a tokenizers library serialization file,
  (2) a slow tokenizer instance to convert or
  (3) an equivalent slow tokenizer class to instantiate and convert.
  You need to have sentencepiece or tiktoken installed [...]

It wasn't listed in pyproject so fresh installs missed it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 17:52:55 +02:00
Pepijn c3503b774f fix(debug): dumper now shows real stream + target flags
The dumper was printing ``stream=None target=None`` for every
message because it read those fields off the message dicts, but
the recipe renderer keeps them in parallel arrays
(``message_streams`` / ``target_message_indices`` in
COMPLEMENTARY_DATA) so the chat template doesn't see unknown
keys. Zip them back into the dump-time dicts so the printed
metadata is accurate.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 16:43:51 +02:00
Pepijn 99ebee4d16 annotate: tighter subtask + memory prompts (≤5 / ≤10 words)
Both feed into the high-level prompt and the plan rendering, so
keeping them short directly reduces the rendered ``${task}\nPlan:
…\nMemory: …`` prefix the model has to chew through at inference.

Subtasks
* Hard cap: ≤ 5 words. Verb + object only, drop articles/adverbs.
* Concrete good/bad examples to anchor the VLM.

Memory
* Hard cap: ≤ 10 words. Telegraphic noun→location fragments
  ("bowl in box, lid open"), no past-tense verbs, drop attributes
  that don't matter for downstream subtasks.
* Allow empty string when no material change occurred — keeps the
  rendered memory line literally blank instead of forcing a no-op
  sentence.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 16:28:09 +02:00
Pepijn a8ca5128b8 fix(annotate): re-emit plan at every subtask boundary
Previously only emitted a plan at t=0 and on interjections, so the
active plan rendered into training carried "done" subtasks until
the next interjection. With the new "plan = remaining subtasks"
summariser this meant the plan was stale between boundaries.

Emit a fresh plan row at every subtask start. ``active_at(t)`` then
returns a plan that contains exactly the subtasks whose start ≥
the current span's start — completed subtasks fall off the plan
the moment the next subtask begins.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 16:26:49 +02:00
Pepijn dd97c33814 refactor(annotate): plan = summary of still-todo subtasks, drop VLM call
The plan was being generated by a separate VLM call (one per
episode + one per interjection refresh) with a prompt that asked
the model to "compress the subtasks into a compact hierarchical
plan". In practice the plans came out longer than necessary and
sometimes drifted from the actual subtask sequence the runtime
would execute.

Replaced ``_generate_plan`` with a deterministic numbered list
of the upcoming subtasks. At a refresh time the list shrinks to
subtasks whose start ≥ refresh_t — the plan describes what's
*left* to do, so it gets shorter as work progresses.

Saves the per-episode + per-interjection VLM round-trip in the
annotation pipeline and keeps train-time plan text bit-aligned
with the subtask annotations the rest of Module 1 emits.

Removed the now-unused ``prompts/module_1_plan.txt``.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 15:55:02 +02:00
Pepijn fa45ba631b fix(policies,recipe): register PI052Config + allow flow-only sub-recipes
Two regressions surfaced by the first training run:

1. ``--policy.type=pi052`` failed with ``invalid choice``. PI052Config
   wasn't imported in ``policies/__init__.py``, so its
   ``@register_subclass("pi052")`` decorator never ran and draccus
   didn't see it as a valid policy type. Mirror PI05Config /
   SmolVLA2Config in the top-level imports + __all__.

2. ``low_level_execution`` (user-only ``${subtask}`` recipe used for
   π0.5-style flow conditioning) tripped
   ``ValueError: Message recipes must contain at least one target
   turn.`` The validator was too strict — a recipe with only a
   ``stream: low_level`` turn still drives meaningful supervision
   (flow MSE on the action expert via ``predict_actions=True``).
   Allow either ``target: true`` OR ``stream: low_level`` to satisfy
   the "supervises something" requirement.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 15:51:47 +02:00
Pepijn ffd8c92ce5 fix(inference): always emit Plan:/Memory: labels in the high-level prompt
The recipe renders ``"\${task}\nPlan: \${plan}\nMemory: \${memory}"``
unconditionally — when a binding resolves to None,
``language_render._substitute`` substitutes an empty string, so the
training-time user turn always contains the literal ``Plan: `` /
``Memory: `` prefixes even with empty values.

The inference message builders were skipping those lines entirely
when ``state['current_plan']`` / ``state['current_memory']`` was
empty, producing a different prompt shape on early frames (before
the plan-generation step runs) and on datasets without plan/memory
annotations.

Factored a shared ``_hirobot_user_head`` helper used by
``_msgs_for_subtask``, ``_msgs_for_memory``, and the legacy
``_control_context_messages`` so they all match training byte-for-
byte regardless of which bindings are populated.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 15:42:29 +02:00
Pepijn 841d3c47e1 feat(debug): LEROBOT_DUMP_RECIPE_SAMPLES=N dumps the first N rendered samples
Adds a one-shot debug dumper to both chat processors. When the env
var ``LEROBOT_DUMP_RECIPE_SAMPLES`` is set to a positive integer N,
the next N samples processed (rank-0 only) get pretty-printed:

* the recipe-rendered messages (role / stream / target / content),
* the full tokenized prompt (decoded back),
* inline ``[TGT]...[/TGT]`` markers over the spans the LM head is
  supervised on,
* token count + target-token count,
* ``predict_actions`` flag.

Usage:

  LEROBOT_DUMP_RECIPE_SAMPLES=5 sbatch train_smolvla2.slurm

After N dumps the helper becomes a no-op; training continues
unaffected. Works for both smolvla2 (chat-template renderer) and
pi052 (plain ``Role: content`` concat renderer); each processor has
its own copy to avoid cross-package imports.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 15:21:46 +02:00
Pepijn 2c920ab178 refactor(recipes): consolidate to shared hirobot.yaml + audit fixes
The smolvla2 and pi052 recipe blends had drifted to identical content
twice in a row; collapse them to a single ``recipes/hirobot.yaml``
both policies point at. Each backbone's text tokenizer (chat-template
for SmolVLA2, plain ``Role: content`` for PI052) handles the
rendering differences downstream — the recipe spec is shared.

Audit fixes folded into the same commit:

* **Train/inference prefix mismatch on the action expert**
  ``_build_text_batch`` always passed ``add_generation_prompt=True``,
  appending ``<|im_start|>assistant\\n`` tokens that the action
  expert never saw at training (the chat tokenizer renders with
  ``add_generation_prompt=False``). Parameterized the helper and
  pass ``False`` from ``LowLevelForward``; ``select_message`` paths
  still default to ``True`` for AR text generation.

* **PI052 fallthrough could silently train flow on text-only frames**
  When ``text_loss_weight=0`` AND every sample was high-level
  (``predict_actions.any()==False``), the previous heuristic
  delegated to ``PI05Policy.forward``, which ignores
  ``predict_actions`` and runs flow on every sample. Reverted to
  delegating only on fully unannotated batches.

* **SmolVLA2 silent zero-loss training**
  ``forward`` returned ``loss=0`` (no error) when neither flow nor
  text path fired. Now raises ``RuntimeError`` with the weights and
  routing flags — fails loud like PI052 already does.

* **PI052 dropout-seed key**
  Was reading ``complementary["dataset_index"]`` (only set by
  ``MultiDataset`` and means "which sub-dataset", not row index)
  with fallback to ``frame_index`` (never set) — every sample got
  seed=0, so per-component dropout was deterministic across the
  epoch. Switched to ``complementary["index"]`` to match SmolVLA2
  and the canonical ``BatchProcessor`` convention.

* **Dead ``DEFAULT_TOOLS`` import**
  Removed from ``chat_processor_smolvla2.py`` — unused since the
  default-tools list was switched to ``[]`` in the prior commit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 15:16:28 +02:00
Pepijn 9f630e2a41 fix(recipes,training): stop tool prompt leak + drop subtask copy-supervision
CRITICAL (smolvla2) — the SmolVLM2 chat template was rendering the
``say`` tool's JSON schema as a system message on every training
sample because ``DEFAULT_TOOLS`` was the default in
``SmolVLA2ChatTokenizerStep``. That schema was only relevant to
the now-removed ``user_interjection_response`` recipe; with it
gone the schema is dead weight that polluted every action-expert
prefix AND created a train/inference mismatch (the inference
``_build_text_batch`` doesn't pass ``tools=``). Default is now
``[]``; callers needing tools can still set them via
``with_tools(meta.tools)``.

LIKELY-BUG — ``low_level_execution`` had ``target: true`` on its
assistant turn, so text-CE trained the LM head to predict the
same subtask string the user just stated (trivial "copy previous
turn" supervision that diluted LM head capacity). Dropped the
assistant turn entirely; ``high_level_subtask`` (w=0.50) already
owns subtask prediction from real context.

The chat-tokenizer's ``predict_actions`` detection used to scan
target streams only. With the new no-target low_level recipe it
would mis-fire as False. Switched both
``chat_processor_smolvla2.py`` and ``text_processor_pi052.py`` to
scan all message streams — any ``stream: low_level`` on the
sample is enough to trigger flow loss.

Inference: the low-level loop sends only ``[user(subtask)]`` now,
matching the new recipe shape.

PI052 — hardened the forward fallthrough so a degenerate batch
where every sample's recipe is text-only AND text supervision is
disabled (text_loss_weight<=0 or text_labels missing) cleanly
delegates to ``PI05Policy.forward`` instead of raising
"nothing to train".

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 14:59:01 +02:00
Pepijn 7a32f8a72a refactor(recipes): π0.5-style split — action expert conditions on subtask only
Previously ``action_execution`` rendered ``task + plan + memory +
subtask`` into one prefix and ran the flow loss on it. That meant
the action expert was conditioned on the full hierarchical context
(closer to π0.7 §V.A), not just the subtask.

The π0.5 paper's hierarchical inference has the action expert see
only the *subtask* (plus images and state). Split the recipe to
match:

  high_level_subtask  (0.50)
    user(task + plan + memory) → assistant(subtask)
    [+ assistant(new_memory) at boundary frames]
    All ``stream: high_level`` → text-CE only, no flow loss.

  low_level_execution (0.30)
    user(subtask) → assistant(subtask)
    Both ``stream: low_level`` → flow loss fires; text CE on the
    subtask is a small redundant extra signal. Prefix the action
    expert sees: [images, subtask, state].

  plan_generation (0.10) — unchanged.
  ask_vqa_{top,wrist} (0.05 each) — unchanged.

Runtime: the low-level loop in ``smolvla2/inference/steps.py``
now sends ``[user(subtask), assistant(subtask)]`` to
``predict_action_chunk`` instead of the full task+plan+memory
context. Falls back to ``state['task']`` when no subtask has been
generated yet so the first frame still has something to condition
on.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 14:13:07 +02:00
Pepijn 129aa207e3 fix(smolvla2,pi052): training-correctness audit fixes
CRITICAL (smolvla2) — text-CE was applied to the wrong prefix slice.
``num_state`` was being read from ``state.shape[1]`` (the raw
max_state_dim, ~14-32) instead of the *number of state tokens*
(always 1). Compounded by the trailing-padding issue (state is
not at the end of the padded prefix when ``seq_len < prefix_length``),
the lang slice was landing on image / padding hidden states.

New ``_locate_lang_range`` finds the state position via
``att_masks.nonzero()`` (the only ``1`` in the mask), making the
slice robust to both bugs. Used by ``_compute_text_loss`` and
``_compute_fused_loss``.

LIKELY-BUG (smolvla2) — ``_unfreeze_lm_head`` only re-enabled
``lm_head`` and ``text_model.model.norm.weight``. SmolVLA's parent
ALSO freezes the last 1-2 transformer layers, so text-loss
gradients died in a frozen final block. Now mirrors the parent's
freeze targets and unfreezes the matching ``layers.{N-1}`` (and
``N-2`` when num_vlm % num_expert == 0).

CRITICAL (pi052) — flow and FAST CE were not per-sample masked
under per-sample-routing. Text-only recipe samples
(``plan_generation``, ``ask_vqa_*``) contributed to flow/FAST
loss with prompts that deliberately omit the subtask, corrupting
the signal. Threaded ``predict_actions_t`` through both
``_compute_all_losses_fused`` and ``_compute_text_and_fast_loss``;
flow uses ``(per_sample * mask).sum() / mask.sum()``, FAST uses
``shift_valid & sample_mask`` before ``masked_fill(-100)``.

OTHER
* PI052Policy.forward now falls through to PI05Policy.forward on
  unannotated batches (no text_labels, no predict_actions, no FAST).
* fit_fast_tokenizer cache key now includes ``chunk_size`` — changing
  the chunk size no longer silently loads a wrongly-fit tokenizer.
* Removed dead ``_compute_text_loss`` / ``_compute_fast_action_loss``
  in pi052 (superseded by the fused helpers).
* Fixed stale "no-op stub" docstring on ``knowledge_insulation`` —
  it's been fully wired since the per-layer KI forward port.
* Stripped unused ``copy`` / ``resize_with_pad`` imports.
* Extracted ``_shifted_ce`` / ``_mask_per_sample`` / ``_fast_ce``
  helpers shared between fused and prefix-only paths.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 14:08:06 +02:00
Pepijn e3ad1c59fc feat(recipes): add plan_generation sub-recipe to smolvla2 + pi052 blends
New text-only sub-recipe at 0.10 weight on both blends:

    user      :  ${task}
    assistant :  ${current_plan}   (high_level target)

Bound to ``active_at(t, style=plan)`` so it supervises the
currently-active plan on every frame, gated by ``if_present`` to
skip frames without a plan annotation.

Weights rebalanced: action_execution 0.85 → 0.75, plan_generation
0.10, VQA top/wrist 0.075 each (sums to 1.0).

Added matching runtime builder ``_msgs_for_plan`` in
``smolvla2/inference/steps.py`` so the high-level loop can call
``select_message`` with the bare-task prompt at episode start /
replanning events.

Closes a gap vs. Pi 0.7 §V — without this recipe the model could
read ``${plan}`` from the prompt but never had to produce one.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 13:51:37 +02:00
Pepijn 9ff62cb08c docs(recipes): trim header comments, drop diversity-knobs note in run_hf_job
Recipes were over-commented (paper citations, history of removed
sub-recipes, inference-time loop walkthroughs). Stripped down to a
short header + a one-line note on the boundary-frame memory tail.

Also removed the ``_tool3`` diversity-knobs comment block in
``examples/annotation/run_hf_job.py`` — it was a personal note about
a since-merged experiment.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 12:55:03 +02:00
Pepijn b2aa372fcf refactor(recipes): fold memory into action_execution, drop interjection, fuse smolvla2 forward
Recipe changes:
* action_execution now bundles the memory update as a second
  assistant target gated on a new ``new_memory`` binding (fires
  only at subtask-boundary frames). No "Completed subtask: X"
  filler — the model emits the new subtask AND the updated
  memory back-to-back in one prefix.
* user_interjection_response sub-recipe removed (current
  datasets don't have interjection / say() annotations).
* Standalone memory_update sub-recipe removed (folded above).
* Weights rebalanced: action_execution 0.85, ask_vqa_top/wrist
  0.075 each (sums to 1.0).

Runtime ``_msgs_for_memory`` updated to match the new
boundary-frame prompt layout.

Modeling:
* SmolVLA2Policy now fuses the flow + text losses into a SINGLE
  backbone forward via ``_compute_fused_loss`` (one
  vlm_with_expert pass with [prefix, suffix] embeds, then both
  lm_head CE on lang slice + action_out_proj MSE on suffix).
  Mirrors pi052's existing ``_compute_all_losses_fused`` —
  saves one backbone pass per training step.

Examples:
* Removed the two training SLURM scaffolds; they were
  out-of-date with the recipe refactor.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 12:51:09 +02:00
Pepijn 058b8f3958 refactor(recipes): two-flavor design — one fused action_execution + text-only events
Both smolvla2_hirobot.yaml and pi052_hirobot.yaml are rewritten as a
clean two-flavor blend, modelled on Pi 0.7 §V.A (Subtask instructions)
and the hierarchical inference pattern from Pi 0.5 §IV.D.

Flavor 1 — action_execution (60% weight, "main path")
-----------------------------------------------------

One always-on recipe that fuses **all** available context (task,
plan, memory) into a single user prompt and uses the current subtask
as the supervised assistant target. This single recipe supervises
*both* objectives:

  * subtask prediction (text CE on the assistant span via lm_head)
  * action chunks (flow MSE on the action expert via
    stream: low_level, target: true; plus FAST CE on action tokens
    when enable_fast_action_loss=True)

At inference, the *same* prompt structure drives both inference
modes:

  * select_message(user_prompt_only) → LM head generates the next
    subtask. Matches action_execution's training distribution
    exactly (prompt is the user turn, target is the subtask).
  * predict_action_chunk(user_prompt + assistant_subtask) → action
    expert produces the chunk. Matches action_execution's full
    prompt+target.

This replaces what used to be a separate high_level_subtask recipe
plus a low_level_execution recipe; both were supervising the same
subtask text, so collapsing them into one is correct and removes
the redundant text-CE gradient.

Flavor 2 — event-driven text-only recipes
-----------------------------------------

Each of these supervises the LM head to predict a specific kind of
text given a specific event-triggered context. ``stream: high_level``
on all targets so they never trigger predict_actions / flow loss.
``if_present`` guards ensure they only fire on frames where the
event annotation is present.

  * memory_update           (10%)  new memory at subtask boundary
  * user_interjection_response (15%) new plan + say(...) on input
  * ask_vqa_top             (7.5%) front-camera VQA
  * ask_vqa_wrist           (7.5%) wrist-camera VQA

Total weight = 1.0.

Prompt format consistency
-------------------------

User prompt template ``${task}\nPlan: ${plan}\nMemory: ${memory}``
matches what ``inference/steps.py::_msgs_for_subtask`` and
``_control_context_messages`` already emit at inference time. No
"Task: " prefix — the bare task string is used as the leading
content with literal "Plan: " / "Memory: " labels for the
subsequent components.

What changed structurally
-------------------------

  - low_level_execution            DROPPED  (folded into action_execution)
  - high_level_subtask             DROPPED  (subtask supervision moved into action_execution)
  + action_execution               NEW      (the fused main recipe)
    memory_update                  kept, prompt cleaned up
    user_interjection_response     kept, prompt cleaned up
    ask_vqa_top / ask_vqa_wrist    kept

Runtime compatibility
---------------------

No runtime change needed — ``SmolVLA2Runtime`` and the inference
helpers already build their high-level prompt as just the user turn
(task + plan + memory) and append a ``current_subtask`` assistant
turn for the low-level call. Both match the new ``action_execution``
prompt shape exactly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 12:35:51 +02:00
Pepijn b873fe454c perf(pi052): full fusion — text + FAST + flow in ONE backbone forward
Previously the forward did 2 backbone passes when all heads were
active: one for flow (via super().forward) and one for the fused
text+FAST helper. This commit reduces it to **one pass** — same
compute as flow-only training.

New ``_compute_all_losses_fused`` builds:

    prefix = [images, language, FAST (when provided)]
    suffix = [noisy_actions]  (action expert via gemma_expert)

and runs a single ``paligemma_with_expert.forward`` with
``inputs_embeds=[prefix_embs, suffix_embs]`` (both experts active
in the same call). Captures *both* prefix_out and suffix_out, slices
each for its respective loss:

    flow MSE     ← suffix_out  (existing action_out_proj + MSE path)
    text  CE     ← prefix_out at language positions (lm_head + CE)
    FAST  CE     ← prefix_out at FAST positions (lm_head + CE)

Critical attention mask override
--------------------------------

``make_att_2d_masks`` produces a cumulative-block attention mask in
which suffix tokens (highest cumsum) attend to *every* lower-cumsum
position by default, including FAST tokens. If we let that stand the
action expert reads the discrete FAST tokens and trivially decodes
them back to the same continuous actions the flow head is supposed
to predict from noise — the entire training signal collapses to a
copy operation.

The fix is a single line right after make_att_2d_masks:

    att_2d_masks[:, fast_end:, fast_start:fast_end] = False

Explicitly zeros out *suffix → FAST* attention. Everything else
remains correct under the cumsum semantics:

  * prefix images/language stay bidirectional among themselves
  * FAST stays causal within itself, attending bidirectionally
    to images+language
  * FAST cannot see suffix (cumsum < suffix cumsum, default)
  * suffix attends bidirectionally among itself, to images+language,
    and now NOT to FAST (this override)

Bit-equivalent to the previous separated forward path for text+FAST
losses (the prefix hidden states at language and FAST positions are
unchanged whether suffix is present or not — the prefix doesn't
attend to suffix). For flow loss, suffix→FAST being masked is the
correct behaviour we *want* — if anything the previous separated
path was less correct for production use because the joint
gradient signal through the action expert was missing the prefix
extension.

Forward routing in ``forward()``
--------------------------------

  * run_flow=True  →  _compute_all_losses_fused (one forward, all
                      three losses)
  * run_flow=False, run_text or run_fast → _compute_text_and_fast_loss
                      (one prefix-only forward, two CE losses, no
                      suffix → cheaper than fusion)
  * neither       →  RuntimeError (explicit; both losses disabled)

Wall-time per step
------------------

  Before this commit:  flow + (text+FAST fused) = 2 forwards
  After this commit:   (flow+text+FAST fused)   = 1 forward

Compute parity with flow-only training when all three heads active.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 12:28:38 +02:00
Pepijn 83d7250a22 fix(recipes): low_level_execution needs if_present:subtask guard too
Same bug we fixed for high_level_subtask, just on the other
subtask-supervised sub-recipe. ``low_level_execution`` targets
``${subtask}`` (the current active span) but had no
``if_present`` guard. When ``active_at(t, style=subtask)`` returned
None at a frame (gaps in the annotation, or the very first/last
frames of an episode if the annotator's spans don't fully tile),
the assistant message rendered with empty content. The chat
tokenizer still included it in ``target_message_indices`` → text CE
supervised whatever the chat-template's empty assistant turn
decoded to (usually a single ``\n``). That trains the LM head's
prior at the first generation position toward ``\n``, the same
collapse we observed with the original ``${next_subtask}`` target.

Fix: ``if_present: subtask`` on the assistant target in
``low_level_execution`` for both ``smolvla2_hirobot.yaml`` and
``pi052_hirobot.yaml``.

Side effect: frames without an active subtask span no longer
contribute to the flow loss either (the only ``low_level`` target
is skipped, ``predict_actions = bool(targets_by_stream.get("low_level"))``
becomes False). For a well-annotated dataset where subtask spans
tile the whole episode this is a no-op. For datasets with gaps,
those gap frames lose flow supervision — strictly better than the
degenerate text-CE alternative.

Sub-recipe audit summary (no other changes needed):

  * memory_update                 — all if_present guards present, OK
  * user_interjection_response    — all if_present guards present, OK
  * high_level_subtask            — fixed earlier, OK
  * low_level_execution           — fixed by this commit
  * ask_vqa_top / ask_vqa_wrist   — query+answer both guarded, OK

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 12:22:45 +02:00
Pepijn 35f9063a6c perf(pi052): fuse text + FAST loss into a single prefix forward
Previously the forward did three backbone passes per training step
when all heads were active: one for flow (via super().forward), one
for text CE, and one for FAST CE. That's ~3× the compute of
flow-only training.

The text and FAST losses share their prefix forward exactly — both
are CE on the LM head, evaluated at different slices of the same
hidden states. Adding FAST tokens after language in the prefix is
bit-equivalent for the text loss because the mask_ar convention in
``make_att_2d_masks`` keeps FAST tokens in a strictly-later causal
block: language tokens never see FAST, so their hidden states are
unchanged.

New ``_compute_text_and_fast_loss``:

  * embeds [images, language] once
  * optionally appends [FAST] (when run_fast is True)
  * one backbone forward
  * slices ``vlm_out[:, -(fast_len + lang_len):-fast_len]`` for
    language hidden states (or ``vlm_out[:, -lang_len:]`` when no
    FAST) → text CE
  * slices ``vlm_out[:, -fast_len:]`` for FAST hidden states →
    FAST CE
  * returns both losses, either of which can be None when the
    caller doesn't want that head.

forward() now calls this fused helper instead of running the two
separate ``_compute_text_loss`` / ``_compute_fast_action_loss``
methods. Those remain in the file for callers that only want one
head (e.g. ablations).

Why flow isn't fused
--------------------

Flow MSE comes from the action-expert (suffix) hidden states, which
attend to the prefix. If we just concat FAST onto the prefix and let
the action expert attend to it, the expert can trivially decode FAST
back to continuous actions — overfitting via shortcut. Preventing
that requires a custom segment-aware attention mask (action expert
can attend to images+language but NOT to subtask/FAST), which is
what pi05_full does in ``compute_layer_complete_knowledge_insulation``.
That's the full-fusion path; deferred as a follow-up since the
text+FAST fusion already recovers most of the compute.

End-to-end forward pass count
-----------------------------

Before: 1 (flow) + 1 (text) + 1 (FAST) = 3 backbone forwards
After:  1 (flow) + 1 (text+FAST fused) = 2 backbone forwards

~33% wall-time reduction per training step when all three heads
are active.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 12:08:34 +02:00
Pepijn 17c0800461 fix(pi052): FAST loss masking + predict_actions gating + smolvla2 review
FAST loss changes
-----------------

1. Gate by ``predict_actions`` (same routing as flow loss). The
   ActionTokenizerProcessorStep tokenises actions for *every*
   sample regardless of which sub-recipe rendered it; for text-only
   recipes (high_level_subtask, memory_update, ...) the action
   tokens are still in the batch but mustn't be supervised. Skip
   the FAST forward+CE entirely when no sample in the batch has
   ``predict_actions=True``.

2. Switch from "multiply-by-mask" masking to ``ignore_index=-100``.
   The old pattern computed per-token CE for all positions, then
   zeroed out invalid ones. Two issues: (a) any out-of-vocab target
   id at a padded position would have crashed cross_entropy before
   the mask got a chance to zero it out, and (b) the pattern is
   needlessly clever. Now ``shift_targets.masked_fill(~mask, -100)``
   followed by ``ignore_index=-100`` cleanly drops invalid positions.
   Matches the smolvla2 text-loss convention.

3. Clean up unused ``bsize`` variable in _compute_fast_action_loss
   and expand the attention-mask docstring with the
   ``make_att_2d_masks`` mask_ar convention spec (causal vs
   bidirectional blocks).

smolvla2 audit (reference review, no code change)
-------------------------------------------------

Compared smolvla2/modeling_smolvla2.py against pi052/modeling_pi052.py
to catch parallel bugs. Findings:

* No ``paligemma.language_model`` vs ``paligemma.model.language_model``
  issue — smolvla2 uses SmolVLM (different class, different attribute
  layout) so the bug doesn't apply.

* ``fill_kv_cache=True`` is correctly passed to smolvla's
  ``vlm_with_expert.forward`` — that class *does* accept the kwarg
  (unlike pi05's PaliGemmaWithExpertModel.forward, which is why
  pi052 must omit it).

* Text-loss alignment is correct: ``_compute_text_loss`` computes
  ``lang_start`` / ``lang_end`` from the known prefix layout
  (``[image_blocks..., lang, state]``) and slices ``prefix_out``
  to just the language positions before applying ``lm_head``. The
  parallel bug I fixed in pi052 (lm_head over the full prefix,
  shape-mismatched against text_labels) was *not* present in
  smolvla2.

* Per-sample flow routing via ``predict_actions``: correctly masks
  per-sample by calling the parent ``forward(..., reduction='none')``
  and applying the predict_actions mask before the mean. pi052 only
  has the batch-level any() gate — a parallel improvement for pi052
  would require modifying PI05Pytorch.forward to support per-sample
  reduction, deferred.

* ``reduction="none"`` returns ``total.expand(bsize)``: identical
  scalar-broadcast limitation in both policies. Acknowledged but
  low priority (only RA-BC weighting uses the per-sample path and
  it's documented as a known approximation in smolvla2).

* Chat tokenizer correctly handles batched/unbatched messages,
  pads with -100 for label positions, builds attention masks. No
  bugs found.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 12:05:37 +02:00
Pepijn c8763e0ad5 fix(pi052): four real bugs in the modeling code + flip defaults
Defaults
--------
* enable_fast_action_loss: False -> True   (match paper §III.B-C Eq.1)
* auto_fit_fast_tokenizer: True -> False   (opt-in; needs base.fit())

Bug fixes
---------

1. Wrong attribute path on PaliGemma. The KI port copied
   pi05_full's ``paligemma.language_model.layers[...]`` literally,
   but the production pi05 wrapper exposes the text model at
   ``paligemma.model.language_model``. With KI enabled, every layer
   would have raised AttributeError on first forward. Fixed all
   references in _compute_layer_ki + _paligemma_forward_ki.

2. ``fill_kv_cache=True`` passed to PaliGemmaWithExpertModel.forward.
   That kwarg is a SmolVLA-only concept; pi05's signature has no
   such argument, so every forward call from pi052 (text loss, FAST
   loss, select_message) would have crashed with TypeError. Dropped
   from all four call sites — pi05's forward already handles the
   cache via past_key_values, and re-forwarding the cumulative
   sequence each step in select_message is fine for our short
   subtask completions.

3. Text-loss shape mismatch. _compute_text_loss applied lm_head to
   the *full* vlm_out (image tokens + language tokens), then tried
   to cross-entropy that against text_labels which only covers the
   language portion — the .view(-1) calls would produce two
   tensors of different lengths and CE would fail. Now slices
   vlm_out to the last text_labels.shape[1] positions before
   running lm_head, matching the [images, language] order
   embed_prefix produces.

4. Dead-code conditional in _paligemma_forward_ki's single-expert
   fallback. The ``if hasattr(...) else self._pi052_orig_forward``
   ternary always took the wrong branch because the attribute is
   always set (we save it in PI052Policy.__init__). Simplified to
   just call self._pi052_orig_forward directly.

After this commit, pi052 should be runnable end-to-end for the
first time with all three loss heads + KI active. Still worth a
100-step smoke test before kicking off a long run.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 11:58:40 +02:00
Pepijn 0f4faddc01 feat(pi052): auto-fit FAST tokenizer per-dataset before training
Per Pertsch et al. 2025 (FAST paper, [64] in π0.5) and π0.5 §III.C,
the recommended practice is to *fit* the FAST action tokenizer on
the specific dataset's action distribution rather than using the
published universal codebook off the shelf. The universal tokenizer
works on any 6-DoF action sequence but produces suboptimal
compression, which slows CE convergence and wastes vocab capacity.

New utility ``lerobot.policies.pi052.fit_fast_tokenizer``:

  * samples N action chunks from the LeRobotDataset (default 1024)
  * loads ``physical-intelligence/fast`` as the base
  * calls ``.fit(actions)`` (the AutoProcessor API the HF model card
    documents) — produces a per-dataset codebook
  * saves to ``{cache_dir}/{sha256(dataset, base, n_samples)[:16]}/``
  * returns the local path, ready to feed
    ``ActionTokenizerProcessorStep(action_tokenizer_name=...)``.

Cache is keyed on (dataset, base tokenizer, sample count) so changing
any of them re-runs the fit. Re-running training on the same dataset
re-uses the cache (one fit per dataset per machine).

Auto-fit wiring:

  * PI052Config gets ``auto_fit_fast_tokenizer`` (default True),
    ``fast_tokenizer_cache_dir`` (default ~/.cache/lerobot/...),
    ``fast_tokenizer_fit_samples`` (default 1024).
  * make_pi052_pre_post_processors now takes ``dataset_repo_id``;
    when ``enable_fast_action_loss`` and ``auto_fit_fast_tokenizer``
    are both True and a repo_id is provided, the factory calls
    ``fit_fast_tokenizer`` before constructing the processor step
    and points it at the fitted path.
  * ProcessorConfigKwargs gains ``dataset_repo_id``; the global
    factory dispatch threads it through for ``pi052`` policies.
  * lerobot_train.py populates ``processor_kwargs['dataset_repo_id']``
    from ``--dataset.repo_id`` for pi052 runs.

Failure mode: if ``.fit()`` fails (e.g. older transformers without
the method, or no usable action chunks in the dataset), the factory
logs a warning and falls back to the universal base tokenizer. Train
still works; you just lose the compression improvement.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 11:52:31 +02:00
Pepijn 8dc0af3c28 feat(pi052): FAST action CE loss + knowledge insulation + processor wiring
Three additions ported from ``pi05_full`` on branch ``feat/add-pi05``,
giving pi052 full paper-§III.B-C training capabilities alongside the
recipe-driven text supervision it already had:

* **Config flags** in PI052Config:
    - ``enable_fast_action_loss``  default False
    - ``action_tokenizer_name``    default "physical-intelligence/fast"
    - ``max_action_tokens``        default 256
    - ``fast_skip_tokens``         default 128
    - ``fast_action_loss_weight``  default 1.0
    - ``knowledge_insulation``     default False

* **Processor wiring** (processor_pi052.py): when
  ``enable_fast_action_loss=True``, append an
  ``ActionTokenizerProcessorStep`` after the text tokenizer. It
  tokenises the action tensor with the FAST tokenizer and writes
  ACTION_TOKENS / ACTION_TOKEN_MASK into ``COMPLEMENTARY_DATA`` —
  the existing batch-collation pipeline forwards them as
  ``batch['action.tokens']`` / ``batch['action.token_mask']``.

* **FAST CE loss** (modeling_pi052.py::_compute_fast_action_loss):
  Re-embeds the prefix [images, language], appends the FAST token
  embeddings (using PaliGemma's shared embed_language_tokens),
  forwards through the backbone, slices the trailing
  ``fast_len`` positions, applies the LM head, computes shifted
  next-token CE with the action-mask gating the loss. The loss is
  summed into ``forward()``'s total with ``fast_action_loss_weight``.

* **Knowledge insulation** (modeling_pi052.py::_compute_layer_ki +
  _paligemma_forward_ki): port of pi05_full's per-layer attention
  that detaches VLM K/V on the action-query path so action loss
  gradients cannot flow back into the VLM's K/V projections. Bound
  per-instance via ``types.MethodType`` so it doesn't leak into
  stock ``pi05`` policies that share PaliGemmaWithExpertModel.
  Activated automatically when ``config.knowledge_insulation=True``.

Combined with the existing recipe-driven text head, pi052 now
supports the full three-loss objective:

   L = text_w·H(text) + fast_w·H(FAST actions) + flow_w·MSE(flow)

matching Eq. (1) of arxiv:2504.16054 §IV.D (α=10 by default for the
flow term, 1.0 each for text and FAST CE).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 11:46:21 +02:00
Pepijn 8eba704f15 Revert "chore(training): align pi052_hirobot.slurm with the operator's actual command"
This reverts commit ecbac17196.
2026-05-13 11:03:58 +02:00
Pepijn ecbac17196 chore(training): align pi052_hirobot.slurm with the operator's actual command
Match the working SmolVLA2 launch pattern so the two SLURM scripts
are interchangeable:

  * literal NUM_PROCESSES / BATCH_SIZE / STEPS (no env-var defaults)
  * STEPS=10000 to match the next SmolVLA2 run
  * save_freq=$STEPS so only the final checkpoint is saved
  * dropouts 0.1/0.1/0.1 (mild — matches the operator's iteration)
  * flow_loss_weight / text_loss_weight come from the PI052Config
    defaults (10.0 / 1.0 per Pi 0.5 paper §IV.D), no need to pass
    them explicitly

Job name and policy_repo_id mirror the SmolVLA2 ``_tool-g2`` naming
so the two runs can be compared side-by-side in WandB.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 11:03:09 +02:00
Pepijn 12cce8f2cc fix(smolvla2): align flow_loss_weight default with Pi 0.5 paper's α=10
Pi 0.5 paper §IV.D Eq. (1) sets the loss balance to α=10 between text
CE and flow MSE: actions are the primary output and the flow head
should dominate the gradient signal. SmolVLA2 was defaulting both
weights to 1.0, which inverts that — text CE (~0.5-2.0 nats) ends up
larger than flow MSE (~0.1-1.0), so the action expert gets less
gradient than the LM head despite being the primary task.

Match the paper's split: text_loss_weight=1.0, flow_loss_weight=10.0.
Same as ``pi052`` (the new full reproduction policy).

Also pin the values explicitly in the SLURM launcher so the choice is
visible and overridable per-run rather than buried in the config
default.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 11:02:17 +02:00
Pepijn ef5879a02a feat(pi052): π0.5 v2 — full reproduction of the π0.5 paper recipe
New ``lerobot.policies.pi052`` (parallel to ``smolvla2``) that adds
text-prediction + hierarchical-inference on top of the existing π0.5
implementation. Mirrors the paper's §IV.D dual-head training:

  L = H(text) + α * ‖ω - a - f_θ_action(...)‖²,  α = 10

Components:

  * ``configuration_pi052.py``     thin PI05Config subclass; adds
                                    recipe_path, text/flow loss weights
                                    (default α=10 per paper), prompt
                                    dropout knobs, ``unfreeze_lm_head``.
  * ``text_processor_pi052.py``    PI052TextTokenizerStep — concatenates
                                    rendered messages as ``Role: ...``
                                    plain text (PaliGemma has no chat
                                    template), tokenises with the
                                    PaliGemma tokenizer, builds a label
                                    mask covering supervised target
                                    spans. Includes Pi 0.7 §V.E
                                    per-component prompt dropout.
  * ``processor_pi052.py``         make_pi052_pre_post_processors —
                                    Rename + Batch + Relative +
                                    Normalize + RenderMessagesStep +
                                    PI052TextTokenizerStep + Device.
                                    Falls back to π0.5's plain pipeline
                                    when recipe_path is unset.
  * ``modeling_pi052.py``          PI052Policy(PI05Policy) — re-enables
                                    PaliGemma ``lm_head``, computes
                                    text_loss via CE on the supervised
                                    span, sums with flow_loss in
                                    forward(), and adds select_message
                                    for AR text generation at inference
                                    (same surface as
                                    SmolVLA2Policy.select_message so
                                    SmolVLA2Runtime drives it unchanged).

Plus the supporting plumbing:

  * recipe ``configs/recipes/pi052_hirobot.yaml`` — same Hi-Robot blend
    as smolvla2_hirobot.yaml, with the same ``${subtask}`` /
    ``if_present`` supervision fix (current span at every frame, not
    ``${next_subtask}``).
  * SLURM ``examples/training/pi052_hirobot.slurm`` — full training
    command matching the SmolVLA2 launcher.
  * factory registration: ``--policy.type=pi052`` resolves to
    PI052Policy with the new processor.

Same multi-rate runtime (``lerobot.policies.smolvla2.inference``)
drives this policy too — both expose ``predict_action_chunk`` for the
action expert and ``select_message`` for the LM head.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 10:59:26 +02:00
Pepijn 1d24301b67 chore(training): STEPS=15000 default + dropout walked back to 0.30/0.30/0.20
After _tool-good (2000 steps, 0.50/0.50/0.20 dropout) the LM head's
distribution at position 0 shifted from EOS to subtask-vocabulary
tokens but emitted bag-of-words ("cube arm and") rather than well-
formed sentences. That's the expected mid-fine-tuning phase: token-
level supervision has landed, sequence-level grammar hasn't.

Two changes for the next retrain:

  * STEPS=15000 (from 2000) — chat-pretrained backbones need O(10k+)
    steps to walk their pretraining priors down far enough to commit
    to the fine-tuned distribution structurally, not just at the
    token level. _tool-g2's bag-of-words output proves the model is
    on the right path; it just needs more gradient signal.

  * plan/memory dropout 0.50 -> 0.30 — 0.50 was probably too
    aggressive for a small dataset. Half the training samples had
    crucial context missing, which slows down learning the full
    conditional structure. 0.30 still regularises against prompt
    leakage but lets the model learn proper grammar first; the
    higher dropout can be revisited once the head is solid.

Subtask dropout stays at 0.20 since subtask isn't in the high-level
prompt anyway (recipe fix removed the "Current subtask:" message).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-13 10:46:19 +02:00
Pepijn 3a20ea337e feat(smolvla2-runtime): --text_min_new_tokens / --text_temperature CLI debug knobs
The recipe fix (target=${subtask} instead of ${next_subtask}) shifted
the LM head's failure mode from "emit newlines" to "emit EOS at
position 0". On the new ``_tool-good`` checkpoint inference produces
exactly one token (``<end_of_utterance>``, id 49279) and decodes to
empty. That's the chat-pretrained backbone's short-turn EOS prior
not yet being overridden by 2000 steps of fine-tuning supervision.

Expose three knobs so the operator can probe whether the head has
real subtask-token probability mass *under* the EOS argmax without
recompiling or retraining:

  --text_min_new_tokens=N    suppress EOS for the first N tokens
  --text_temperature=T       sample at temperature T
  --text_top_p=P             nucleus filtering at top-p

These are explicitly off-policy (training was greedy / no min-tokens),
so they shouldn't ship in production runs — but they let us tell
whether the model has *learned* subtask prediction (just under EOS)
or hasn't yet. If forcing min_new_tokens=3 with temperature=0.5
produces a sensible subtask, the model is fine and just needs more
training steps to walk EOS down. If it produces gibberish, training
hasn't progressed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 21:39:33 +02:00
Pepijn b6fb536460 chore(training): bump plan/memory dropout to 0.50 to force vision-grounding
After the recipe fix (target=${subtask} at every frame) the model
can still reach low text_loss by reading the answer off the plan in
the prompt: at training the prompt contains the 6-step plan, and the
current subtask is one of those steps, so the model just learns
"active step N matches subtask N" and never needs to look at the
image. Symptom at inference: subtask string is set but never updates
because the model isn't really conditioning on the visual progress.

Drop plan and memory with p=0.50 each — half of training frames the
prompt is just "${task}" (constant for this dataset) + visual prefix,
which is the only place the answer can come from. Forces the LM head
to actually use vision.

``subtask_dropout`` stays at 0.20 because subtask isn't in the
high-level prompt anymore (recipe fix removed the "Current subtask:
X" message); the knob still affects other sub-recipes that reference
it as context.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 21:31:00 +02:00
pepijn bfd3bb1791 fix(smolvla2): handle batched sample indices in chat tokenizer
Normalize tensor and sequence sample indices before prompt dropout so distributed batched preprocessing does not try to cast full index tensors to scalars.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-12 16:56:13 +00:00
Pepijn 4908433f9a chore(training): align smolvla2_hirobot.slurm with what's actually run
Match the operator's current training command for the _tool6 retrain:

  * default DATASET / POLICY_REPO_ID / JOB_NAME point at the tool6
    iteration (super_poulain_full_tool3 → smolvla2_hirobot_super_poulain_tool6)
  * STEPS default 2000 (short enough to iterate; bump to 10k for full)
  * save_freq=$STEPS so the only checkpoint is the final one
  * OUTPUT_DIR includes step count so successive runs don't clobber
  * Drop the wider augmentation envelope I added earlier — back to
    default ColorJitter ranges (brightness ±20% etc) since the
    high_level_subtask recipe fix (current-subtask supervision) is
    expected to fix the LM-head collapse on its own; the augmentation
    is just the standard regulariser, not a load-bearing widener.
  * prompt-dropout fractions stay at the original 0.15 / 0.15 / 0.20.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 18:45:38 +02:00
Pepijn 6ce1f36002 fix(smolvla2): supervise high-level head with *current* subtask at every frame
The high_level_subtask recipe targeted ``nth_next(style=subtask, offset=1)``,
which on the last span of any episode resolves to None. The recipe had no
``if_present`` guard on the target, so the renderer emitted an empty
assistant turn and cross-entropy supervised the model on the chat
template's structural newlines (``\n``). Across the dataset this trained
the LM head's argmax at position 0 to collapse to ``\n`` whenever no
transition was imminent (i.e. most frames). Visible failure mode at
inference: the head emits 40+ newlines + ``<end_of_utterance>`` every
chunk boundary while the action expert keeps working — confirmed by
running the dry-run on dataset frame 0 with the dataset's own image
and seeing the same ``\n × 44`` collapse.

Switch to the Pi 0.5 / Pi 0.7 supervision pattern: at every frame, the
assistant target is the *current* active subtask span text (via
``${subtask}`` → ``active_at(t, style=subtask)``). Always non-empty,
always scene-grounded, ``if_present: subtask`` skips frames with no
active span instead of emitting a degenerate empty turn.

Runtime callsite update: ``_msgs_for_subtask`` no longer feeds a
"Current subtask: X" user message into the prompt (that would be
circular — we'd be telling the model the answer). Transition
detection moves into the runtime — when the predicted subtask differs
from ``state['current_subtask']``, the existing ``set_if_changed``
path fires ``subtask_change`` and downstream memory updates. Same
event surface, supervision target is now always meaningful.

Requires re-annotating the dataset and retraining for the fix to land
in the checkpoint, but the recipe + runtime change is what enables it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 18:42:59 +02:00
Pepijn 731576be80 chore(smolvla2-runtime): auto-fire one tick at dry-run startup
Previously the dry-run REPL only ticked on user input (empty Enter
just redrew), so the bisection test "does the LM head produce text on
start_frame=0?" required typing something arbitrary to drive a tick.
Just run ``step_once`` at startup — the obs diagnostic *and* the
subtask gen both fire automatically, the diag row populates, and the
operator can read the result before pressing any key.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 18:34:42 +02:00
Pepijn 47fb8318b1 chore(training): widen augmentation envelope after live-robot diagnostic
The tensor-level comparison between dry-run (dataset frame) and live-
robot inference proved the runtime is bug-free — same shape, dtype,
device, channel order, batch dim, and normalization on both paths.
The remaining variable: front-camera mean brightness was 0.26 live vs
0.39 on the dataset frame, ~33% darker. Training augmentation only
covered ±20% brightness, so the live scene sits just outside the
supervised envelope and the LM head collapses to its dominant prior.

Widen the augmentation knobs for the next retrain:

  * brightness    0.8–1.2  → 0.5–1.6   (covers ~30% darker / 60% lighter)
  * contrast      0.8–1.2  → 0.6–1.5
  * saturation    0.5–1.5  → 0.3–1.7
  * hue          ±0.05    → ±0.10
  * affine        ±5°/±5%  → ±15°/±15% (covers cube placement / camera drift)
  * max_num_transforms 3 → 4

And bump prompt-component dropout (subtask 0.20 → 0.30) so the LM
can't lean on stale memorised plan/memory at inference.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 18:25:41 +02:00
Pepijn 53172873e3 chore(smolvla2-runtime): probe obs once at dry-run startup
The dry-run REPL only fires a tick when the user types, so the
``_log_obs_tensors_once`` diagnostic never reached stdout (the
provider was never called). Probe the provider once at startup —
the result is discarded; we only care about the obs log it triggers.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 18:21:58 +02:00
Pepijn fcdae0ce8e chore(smolvla2-runtime): tensor-level obs print for both inference paths
Helper that prints (once per provider lifetime) every
``observation.*`` tensor the policy is about to see, with its shape,
dtype, device, and per-channel min/max/mean/std. Wired into both the
dry-run dataset path and the live-robot path.

Now we can bisect train/inference mismatch *at the tensor level* —
if the same checkpoint produces coherent text on one path's tensors
and ``\n`` on the other's, and the printed tensor stats differ
materially, the bug is in the observation prep, not in the model or
the training distribution.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 18:19:18 +02:00
Pepijn 4852b9f952 feat(smolvla2-runtime): --dataset.augment_at_inference for the bisection test
Apply the training-time torchvision-v2 ColorJitter / SharpnessJitter /
RandomAffine pipeline to dataset frames in dry-run, so we can isolate
whether the LM head's collapse to '\n' on live frames is:

  * pure scene-content OOD (unaugmented dataset frames work, mildly
    augmented ones still work — model has learned the augmentation
    distribution, only fails when the scene content itself diverges)
  * hyper-specific memorisation (dry-run with augmentation also
    collapses to '\n' — head is nailed to the exact unperturbed
    training samples and only the retrain helps)

Usage:

  lerobot-smolvla2-runtime --no_robot --policy.path=... \
    --dataset.repo_id=... --dataset.episode=0 \
    --dataset.start_frame=1000 \
    --dataset.augment_at_inference

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 18:14:57 +02:00
Pepijn 0410705aff chore(smolvla2-runtime): print live state vector once at startup
So the operator can compare live joint values to the dataset's
``observation.state`` mean/std and spot when the robot's home pose is
several σ off the supervised support region. State OOD is the
remaining viable hypothesis for why the live LM head collapses to
``\n`` even though images are pixel-shape-matched.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 18:12:27 +02:00
Pepijn 398a8cf730 chore(smolvla2-runtime): log first-tick resize so train/inference match is verifiable
Print one warning the first time the robot observation provider runs
through, showing live camera resolution and the dataset's training
resolution, plus whether we resized. Lets the operator confirm at a
glance that the visual prefix really is being fed at the same shape
the model saw at training — instead of guessing whether the resize
fired silently.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 18:06:00 +02:00
Pepijn ab5c1dc392 fix(smolvla2-runtime): match training visual distribution on robot frames
Root cause for the LM head's empty-completion symptom on the live robot
(while the same checkpoint produced sensible subtask/plan/memory in
``--no_robot`` dry-run on dataset frames): the camera observation was
flowing into the model at its native resolution. A Mac/USB webcam
hands us 1280×720 or 1920×1080; the dataset was recorded at the
feature schema's ``observation.images.*['shape']`` resolution
(typically 480×640). SmolVLA's internal ``resize_with_pad(512, 512)``
*does* fit both — but with very different pad geometry, so visual
tokens at each tile carry different content than at training. Action
expert tolerates this; the tightly-supervised LM head goes OOD and
the head's distribution at position 0 collapses to its dominant mode
(``\n`` ×N then ``<end_of_utterance>`` for this checkpoint).

The fix: in ``_build_robot_observation_provider``, pre-compute the
camera-key → (H, W) target from ``ds_features`` and ``cv2.resize``
each live frame to that shape before tensorising. The downstream
``resize_with_pad`` then sees the same input geometry as training and
the LM head returns to producing readable subtask text under plain
greedy decoding — the same as dry-run.

Also drops the inference-time patches (``min_new_tokens``,
``temperature``, ``top_p`` overrides) on the four high-level callers.
They were band-aids around the visual-distribution shift, not a real
LM problem, and they drift inference off the training distribution.
Greedy argmax is what training matched. The ``select_message``
signature still accepts the knobs for callers that want them.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 17:59:24 +02:00
Pepijn 1292304c42 fix(smolvla2): suppress all special tokens during min_new_tokens window
Previous attempt only masked the tokenizer's eos_token_id during the
min_new_tokens prefix. The empty-completion symptom persisted because a
memorised SmolVLM head doesn't just want EOS — its top-1 at position 0
is *some* special token, and when EOS is masked the argmax shifts to a
sibling (``<|im_end|>``, ``<image>``, ``<fake_token_around_image>``,
``<row_X_col_Y>``, …). Those tokens survive generation but then get
stripped by ``decode(skip_special_tokens=True)``, so the runtime still
saw ``last_raw='(empty)'`` every chunk boundary.

Mask the full ``tokenizer.all_special_ids`` set instead. Forces the
head to commit to a normal vocabulary token before it can close or
quietly poison the turn.

Also: when decode returns empty but tokens *were* generated, expose
the raw token ids and the special-tokens-included decoded string via
``policy._last_select_message_debug``. The runtime surfaces this in
the scrollback so the operator can see what the head is actually
emitting — distinguishing "head EOS-ing" from "head emitting image
placeholders" from "head emitting chat-template fragments".

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 17:49:53 +02:00
Pepijn b95eebff77 fix(smolvla2): force min_new_tokens + sampling so memorised LM emits something
Real-robot run confirmed the LM head is producing 0 tokens at every
chunk boundary (empty:N counter climbing, no exception in scrollback):
the model EOS-es at decode step 0. That's the memorisation collapse —
training reached text_loss=6e-6 by overfitting one trajectory whose
supervised subtask turn ended in EOS, and at inference the head's
argmax for token 0 is EOS regardless of the actual frame.

Two changes in select_message:

  * ``min_new_tokens`` parameter masks the EOS logit to -inf until at
    least N real tokens have been decoded. Without this the head's
    "EOS first" prior produces an empty completion every single time.

  * The runtime callers now pass ``min_new_tokens=5..10`` plus
    ``temperature=0.4..0.5`` + ``top_p=0.9``. Sampling at moderate
    temperature with nucleus filtering also helps break the greedy
    argmax collapse — when the model has memorised one continuation,
    greedy keeps replaying it; nucleus sampling forces it to commit
    to *some* coherent continuation that's well-supported by the
    prefix even when greedy's top-1 is degenerate.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 17:48:08 +02:00
Pepijn fbcac95662 feat(smolvla2-runtime): scrollback in autonomous panel + empty-gen counter
Two improvements for diagnosing why ``last_raw`` stays empty:

1. The autonomous panel-redraw thread calls console.clear() every
   0.5 s, wiping any log lines the runtime printed since the last
   redraw. So warnings from generation (``[warn] subtask gen failed:
   ...``, ``[info] subtask gen rejected (gibberish): ...``) flashed
   for milliseconds and disappeared, leaving the operator blind.

   Capture log_lines from each tick into a bounded scrollback
   (last 12 entries) and render them inside the panel itself, below
   the diag row. They now stick across redraws until rotated out.

2. ``empty`` counter for subtask gen. Persistent empty completions
   are their own failure mode — the LM head EOS-es immediately from
   the chat-template generation prompt, distinct from "generated
   something but filter rejected it". The diag row now reads:

     subtask diag    repeat:0  gibberish:0  empty:14  last_raw: '(empty)'
                                            ^^^^^^^
   plus a periodic log line every 10 empties so the cause is also
   surfaced in the scrollback.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 17:42:13 +02:00
Pepijn b9db4d21a2 fix(smolvla2): high-level steps must run before LowLevelForward refills
Both HighLevelSubtaskFwd and LowLevelForward are gated on
'action queue is empty'. With LowLevelForward listed first, it refilled
the queue on the empty-queue tick before HighLevelSubtaskFwd got to
check — so the gate I added in the previous commit made the high-level
step a permanent no-op after the initial bootstrap. Visible symptom:
subtask string never advances past whatever bootstrap seeded, no
subtask_change events, memory stays unset, and the new overfit
diagnostics never appear on the panel because last_subtask_raw is
never written.

Move all high-level steps (subtask, memory, interjection, vqa) ahead
of LowLevelForward. On an empty-queue tick the subtask refreshes
first, the new string flows into the next chunk's prompt, then
LowLevelForward generates the chunk, then DispatchAction drains it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 17:38:06 +02:00
Pepijn aecb80a9d2 feat(smolvla2-runtime): overfit/memorisation diagnostics on the panel
The autonomous-mode panel now surfaces what the model is *actually*
producing at every chunk boundary, not just what got accepted:

  * last_subtask_raw       most recent generation (accepted or not)
  * subtask_repeat_count   times the same accepted string regenerated
  * subtask_gibberish_count rejections by the gibberish filter
  * memory_gibberish_count / plan_gibberish_count for the other heads

These let the operator see memorisation collapse without scrolling
back through logs:

  subtask diag    repeat:8  gibberish:0  last_raw: '<same string>'
                  ^^^^^^^^^^ → model can't move past current phase

  subtask diag    repeat:0  gibberish:14  last_raw: 'Ass:::'
                  ^^^^^^^^^^^^^^^^^^^^^^ → LM collapsed to template salad

Also silences the per-action ``Relative goal position magnitude had
to be clamped`` warning. The clamp fires every dispatch tick when the
model emits stale joint targets, flooding the panel at ctrl_hz=30.
Replaced the bare ``logging.warning`` call in robots/utils.py with a
module logger so it can be selectively raised to ERROR. Operators
who need the per-tick clamp detail can use ``-v``.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 17:31:04 +02:00
Pepijn c98c695127 feat(smolvla2-runtime): 'rephrase:' prefix to swap task string in place
Adds a third stdin channel alongside 'task:' and bare interjections:

  rephrase: <text>

Swaps state['task'] with the new string while preserving plan/memory/
subtask. Lets the operator probe how robust the model is to wording
variations of the same task — the trained augmentation provided
n_task_rephrasings≈30 task wordings per dataset task, and this is the
direct way to exercise that distribution at inference without
generating a fresh plan via user_interjection_response.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 17:26:59 +02:00
Pepijn d528078aca fix(smolvla2-runtime): allow task switching mid-run via 'task:' prefix
Both stdin handlers (autonomous mode and rich REPL) gated 'task:' to
'only if no task is set yet' — once the initial task existed, typing
'task: <new task>' silently fell through to the interjection branch.
Make 'task:' always override the active task and clear stale
plan/memory/subtask so the next high-level pass regenerates context
from scratch for the new task.

For rephrasings within the same task, the interjection path
(user_interjection_response recipe) is still the right channel — it
refreshes the plan and emits a paired <say> in one trained call.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 17:24:16 +02:00
Pepijn a648da0455 fix(smolvla2): unblock action dispatch when high-level LLM stalls loop
The runtime is single-threaded. `HighLevelSubtaskFwd` at HzTrigger(1.0)
fires every loop iteration on MPS because each `select_message` call
takes ~2 s, longer than its 1/hz period. The whole tick stretches to
~2.5 s, so `DispatchAction` (HzTrigger 30) only pops a single action per
loop iteration — the queue drains at ~0.4 actions/sec instead of 30 and
the robot barely moves between chunk refreshes.

Two changes, both purely about scheduling — no threading:

* Gate `HighLevelSubtaskFwd` to fire only when the action queue is
  empty, matching `LowLevelForward`'s refresh condition. The slow LLM
  call now happens during the "think" phase between chunks, not on
  every dispatch tick. Restores a clean sense → think → act cycle.

* `DispatchAction` catches up via wall-clock: when the trigger fires
  after a stall, pop `round(elapsed * hz)` entries and send only the
  most recent. Open-loop chunks are timestamped at ctrl_hz; sending
  stale joint targets one-by-one would just lag the robot further
  behind. The dynamixel smooths to the latest goal anyway.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 17:23:09 +02:00
Pepijn d866c2c9fd fix(smolvla2): only regenerate chunk when queue is fully drained
The previous refresh threshold (queue > chunk_size // 2) made each
new chunk *telescope* past the previous one: at queue=25, we kicked
off a new chunk forward from the current observation, but by the
time the new chunk's first action was actually dispatched, the
robot had executed the remaining 25 actions of the previous chunk
— so the new chunk was planned from an observation 25+ steps stale.

Canonical sense → think → act loop: execute the full chunk, then
re-observe and replan. Refresh only when the queue is empty. Every
step of every chunk still gets dispatched to the robot (no
behaviour change there), but each chunk is now planned from an
observation that's at most one chunk's worth of dispatch latency
old, not "previous chunk's worth of stale state on top of that".

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 17:15:02 +02:00
Pepijn 01e2228b24 feat(smolvla2): per-component prompt dropout + augmented training script
Two complementary regularisers to attack the
``text_loss=6e-6 = memorised one dataset`` failure mode that's
making the model collapse on real-robot input:

1. **Per-component prompt dropout** (Pi0.7 §V.E / plan's
   ``feat/pi05-prompt-dropout`` follow-up).
   ``SmolVLA2ChatTokenizerStep`` gains
   ``plan_dropout_prob`` / ``memory_dropout_prob`` /
   ``subtask_dropout_prob`` knobs (default 0.0 — opt-in). At training,
   non-target messages whose rendered content starts with
   ``Plan:`` / ``Memory:`` / ``Current subtask:`` etc. are dropped
   with their respective probability before tokenisation, with a
   deterministic per-sample RNG keyed off the dataset ``index``.
   ``target_message_indices`` is re-mapped so the supervision still
   lands on the right turn. Forces the model to handle missing
   plan/memory/subtask context — directly attacks the real-robot
   collapse where a stale or empty plan field puts the prompt OOD.

   Surfaced on ``SmolVLA2Config`` as three floats so they're
   ``--policy.<knob>=<value>``-controllable from the train CLI;
   plumbed through ``make_smolvla2_pre_post_processors``.

2. **Image augmentation** is already wired in lerobot via
   ``--dataset.image_transforms.enable=true`` (torchvision v2
   ColorJitter + SharpnessJitter + RandomAffine, default 3 of 6
   sampled per frame). No code change needed — just a CLI flag.

``examples/training/smolvla2_hirobot.slurm`` shows the full
training command with both enabled. Drop-in replacement for the
ad-hoc SLURM script Pepijn was using locally; same args, plus the
three dropout probs and the image-transforms flag.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 15:52:32 +02:00
Pepijn c36de3a3e8 fix(smolvla2): enqueue full chunk via predict_action_chunk
``LowLevelForward`` was calling ``select_action()`` once per
``chunk_hz`` tick. SmolVLA's ``select_action`` is a thin queue-pop:
it returns one action per call and only re-runs the expensive
flow-matching forward when its private internal queue empties.
Result: we got one action back per chunk_hz tick (1Hz default),
``DispatchAction`` at ctrl_hz=30 popped it instantly, then queue
sat empty for ~1s waiting for the next tick. Net throughput was
1 dispatched action/sec instead of the 30 we wanted.

Switch to ``predict_action_chunk`` and enqueue every step of the
returned ``(batch, n_action_steps, action_dim)`` chunk. Refresh
only when the queue is below half a chunk so we don't burn one
flow-matching forward per chunk_hz tick — saves ~5x inference cost
on this hot path. At ctrl_hz=30, chunk_size=50, the queue drains
in ~1.7s before the next refresh, giving smooth dispatch at the
control rate the robot was trained on.

Side effect: ``state['last_chunk_size']`` records how many actions
the most recent chunk produced — useful for the panel later if we
want to surface "chunks generated" alongside "dispatched".

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 15:27:23 +02:00
Pepijn cbfaf2c544 feat(smolvla2): action-dispatch counter + tighter gibberish filter
Real-robot run was unreadable for two reasons:

1. The panel surfaced ``queued actions: 0`` (always zero — dispatch
   pops faster than chunk_hz generates) and gave no signal that
   actions were actually reaching the robot. The only sign of life
   was the safety-clamp warning lines scrolling past.

2. The text head consistently collapses to ``the`` / ``Ass``
   fragments on real-camera input (memorisation wall). The old
   gibberish filter caught ``":":":"`` JSON salad but let
   single-token fragments through, and the ``[info] subtask gen
   produced no text this tick`` line flooded the panel every second.

Changes:

  * ``DispatchAction`` bumps ``state["actions_dispatched"]`` each
    tick; panel renders it next to queue depth. Operator can see
    the policy IS issuing actions even when text is broken.
  * ``_looks_like_gibberish`` now also rejects:
    - too few unique alphabetic tokens (``the``, ``the the``, ...)
    - chat-template marker leakage (``Assistant:``, ``Ass\\n::``)
    catching the actual failure mode on real-robot frames.
  * Gibberish rejections log only the first occurrence + every 30th
    after that, with a count, so the panel stays legible.
  * Empty completions no longer log at all (was every tick).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 15:22:36 +02:00
Pepijn d0278ea093 feat(smolvla2): render state panel in autonomous mode too
Dry-run REPL had a clean ANSI-clear-+-rich-panel layout via
``_redraw`` showing task / subtask / plan / memory / queued-actions /
pending-tool-calls; autonomous mode just had bare ``> `` plus log
lines scrolling past the user. Same data, two presentations.

Extract ``_make_state_panel_renderer(runtime, mode_label=...)`` and
use it from both ``_run_repl`` (called per user input) and
``_run_autonomous`` (called both on user input *and* on a 0.5s
background timer so subtask / plan / memory refreshes from the
runtime's own loop become visible without the user typing anything).
Title bar shows ``dry-run`` vs ``autonomous`` so it's obvious which
mode you're in.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 15:16:28 +02:00
Pepijn 15f6b08b0e fix(smolvla2): use canonical _strip_lerobot_blocks for inference msgs
Training tokenises messages through ``_strip_lerobot_blocks`` (in
``chat_processor_smolvla2.py``), which normalises every variant of
``message['content']`` into the ``[{type:text, text:...}]`` list shape
SmolVLM's chat template expects:

  * ``list[block]`` → keep text blocks, drop images
  * ``None``        → ``[{type:text, text:""}]``
  * ``str`` / other → ``[{type:text, text:str(content)}]``

Inference was doing a partial inline conversion that only handled the
``str`` case — ``None`` and pre-formatted ``list`` content slipped
through unchanged. ``memory_update``'s ``Previous memory: ...``
assistant turn ends up with ``None`` content when there's no prior
memory, which then renders as no-content / role-marker-only and the
model hallucinates ``Assistant:`` fragments. Subtask gen got further
because its prompt always has at least the task string.

Reuse ``_strip_lerobot_blocks`` directly. Now the inference prompt
shape matches the exact tokenisation training did — no more "trained
on shape X, asked to predict shape Y" mismatch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 15:07:39 +02:00
Pepijn fc715db4a3 fix(smolvla2): coerce str content to list-of-blocks for chat template
SmolVLM's chat template (and many other multimodal templates) declares
``message['content']`` as a list of typed blocks and iterates it
expecting dicts with a ``'type'`` field:

    {% for line in message['content'] %}
      {% if line['type'] == 'text' %}{{ line['text'] }}
      {% elif line['type'] == 'image' %}{{ '<image>' }}
      {% endif %}
    {% endfor %}

When the caller passes ``content`` as a plain ``str`` (which we did
throughout ``_msgs_for_subtask`` / ``_msgs_for_memory`` etc.), Jinja
silently iterates the string character-by-character. ``'P'['type']``
returns nothing; neither branch fires; *no text tokens get emitted*.
The model receives a prompt containing only role markers
(``User:<end_of_utterance>\nAssistant:``) and predictably continues by
emitting ``Assistant:`` fragments — the gibberish ``subtask: Ass\n::``
on the runtime panel.

Before calling ``apply_chat_template``, walk the messages and rewrite
any string ``content`` into ``[{'type': 'text', 'text': content}]``.
The template's text branch then fires correctly and the model sees
the actual user/assistant text, not just structural tokens.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 15:01:53 +02:00
Pepijn fe4bd2b6ba fix(smolvla2): pass flat batch dict to preprocessor (no manual wrap)
``PolicyProcessorPipeline.__call__`` already wraps its input via
``to_transition`` (defaulting to ``batch_to_transition``) before
running the steps, and unwraps via ``to_output`` (defaulting to
``transition_to_batch``) afterwards. The input format is therefore a
*flat batch dict* keyed by ``observation.*`` / ``action`` / etc., not
an ``EnvTransition``.

Previous attempt pre-wrapped the observation into a transition with
``TransitionKey.OBSERVATION`` as the key, then handed *that* to the
pipeline — which fed it to ``batch_to_transition``, which looked for
top-level ``observation.*`` entries, found none (they were nested
inside the enum key), and produced an empty observation. Every step
then bailed with ``ObservationProcessorStep requires an observation
in the transition.``

Pass the flat dict from ``build_inference_frame`` straight to the
preprocessor — it does the wrap/unwrap itself.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 14:54:48 +02:00
Pepijn 3f7436ff8a fix(smolvla2): use TransitionKey enum (not .value) as transition keys
``EnvTransition`` is declared as a ``TypedDict`` keyed by
``TransitionKey.OBSERVATION.value`` (the string ``'observation'``),
but every concrete ``ProcessorStep`` in the pipeline indexes the
transition with the enum *member* (``transition[TransitionKey.
OBSERVATION]`` / ``transition.get(TransitionKey.OBSERVATION)``).
Those are two different keys in a Python dict — string key vs enum
key — so steps couldn't find the observation we'd placed under the
string variant, and bailed every tick with
``ObservationProcessorStep requires an observation in the
transition``.

Build the transition with the enum members directly. Matches how
``BatchProcessor``, ``RelativeActionProcessor``, ``HilProcessor``,
etc. read the dict.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 14:50:22 +02:00
Pepijn 992d13d4e9 fix(smolvla2): use build_inference_frame for raw robot observations
``robot.get_observation()`` on omx_follower (and most lerobot robots)
returns:

  * per-joint scalar floats with ``.pos`` suffix
    (``shoulder_pan.pos: 0.123``, ``shoulder_lift.pos: 0.456``, ...)
  * per-camera ndarrays keyed by the camera config name (``wrist:
    ndarray(H,W,3)``)

But the trained policy expects:

  * single ``observation.state: tensor[N_joints]`` vector
  * image keys prefixed: ``observation.images.<cam_key>:
    tensor[1, 3, H, W]``

``prepare_observation_for_inference`` only handles the tensor /
batch-dim / device step — it crashes on scalar floats with
``expected np.ndarray (got float)``. The right helper is
``build_inference_frame`` which uses the dataset's feature schema
(``ds_meta.features``) to:

  1. extract the right raw keys per dataset feature,
  2. fold ``shoulder_pan.pos`` / ``shoulder_lift.pos`` / ...
     into a single ``observation.state`` ndarray,
  3. prefix camera keys with ``observation.images.``,
  4. delegate to ``prepare_observation_for_inference`` for the
     tensor / batch / device step.

Pass ``ds_meta.features`` into the observation provider and switch
to ``build_inference_frame`` when available; fall back to the bare
``prepare_observation_for_inference`` only when no dataset is
provided (rare — autonomous mode already requires it).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 14:47:59 +02:00
Pepijn afe40a016b fix(smolvla2): wrap robot obs in EnvTransition before preprocessor
The policy preprocessor pipeline is transition-shaped — its steps
read ``TransitionKey.OBSERVATION`` off an ``EnvTransition`` dict, not
a flat ``RobotObservation`` dict. Passing the raw observation through
made every step bail with
``ObservationProcessorStep requires an observation in the transition``,
which the runtime swallowed at warning level. ``select_message`` then
got called with no ``observation.images.*`` features and crashed
with ``All image features are missing from the batch``.

Mirror ``lerobot-record``'s preamble:
  1. ``prepare_observation_for_inference`` → numpy → torch, ``CHW``
     image layout, ``[0,1]`` scaling, add batch dim, move to device.
  2. Wrap into an ``EnvTransition`` (``{TransitionKey.OBSERVATION.value:
     ...}`` plus ``COMPLEMENTARY_DATA: {}`` and ``None``s for the rest)
     so transition-aware steps see the keys they expect.
  3. Run preprocessor.
  4. Unwrap the transition's ``OBSERVATION`` slot to get the final
     flat dict the policy's ``select_action`` / ``select_message``
     consume.

Image features now reach the policy; the autonomous loop produces
real actions instead of swallowing warnings every tick.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 14:44:24 +02:00
Pepijn 41095e3cc3 fix(smolvla2): instantiate CameraConfig subclasses from JSON dicts
``--robot.cameras`` parses the JSON into ``dict[str, dict]``, but
``RobotConfig`` expects ``dict[str, CameraConfig]`` — each inner
value must be the actual ``CameraConfig`` subclass instance for the
chosen backend (e.g. ``OpenCVCameraConfig``). Passing raw dicts
blew up in ``RobotConfig.__post_init__`` with
``AttributeError: 'dict' object has no attribute 'width'`` when it
iterated cameras and tried to read attributes.

Look up the right subclass per-camera by its ``"type"`` field via
``CameraConfig.get_choice_class(...)`` (mirroring the lazy-import
dance we already do for ``RobotConfig``: eagerly walk
``lerobot.cameras``'s submodules so the registry is populated
before lookup). Construct an instance with the rest of the dict's
fields. On an unknown camera type, raise a clean ``ValueError``
listing the available choices.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 14:39:28 +02:00
Pepijn e0fa957569 fix(smolvla2): eagerly import robot submodules before get_choice_class
``RobotConfig._choice_registry`` is populated as a side-effect of
each robot's ``@RobotConfig.register_subclass`` decorator running,
and those decorators only fire when the corresponding
``lerobot.robots.<name>`` module is imported. The package's
``__init__.py`` doesn't import them — instead ``make_robot_from_config``
does it lazily in its big if/elif chain.

``_build_robot`` jumped the gun: called ``RobotConfig.get_choice_class
(robot_type)`` before any robot module had been imported, so the
registry was empty and every ``--robot.type=<X>`` produced
``KeyError: 'X'`` (e.g. ``KeyError: 'omx_follower'``).

Walk ``lerobot.robots``'s submodules via ``pkgutil.iter_modules`` and
``importlib.import_module`` each one before the lookup. ~200ms on the
first invocation, negligible for an autonomous run. On a real
``KeyError`` (typo / unsupported robot), raise a clean ``ValueError``
listing the registry's available choices instead of a bare KeyError.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 14:31:58 +02:00
Pepijn c661d81409 fix(smolvla2): use RobotConfig.max_relative_target, drop --max_action_norm
The hand-rolled action-norm safety clip duplicated what every
``RobotConfig`` already exposes — ``max_relative_target`` — and at
the wrong layer (after postprocess but before send_action, instead
of inside the robot driver where every other lerobot entry point
puts it). The norm clip also rejected entire actions instead of
clipping per-motor relative motion, so a single rogue joint would
kill the whole tick.

Replace with ``--robot.max_relative_target``: a string parsed as
either a bare float (uniform per-motor cap) or a JSON object
mapping motor name → cap. Passed through to
``RobotConfig(max_relative_target=...)`` at robot construction;
the driver's ``send_action`` clips each commanded joint position
relative to the current measured one before issuing it on the bus —
same behaviour ``lerobot-record`` ships.

Also bump ``--chunk_hz`` default from ``4.0`` to ``1.0``. One new
chunk per second is what the trained checkpoint can comfortably
keep up with on common hardware and gives smoother motion than
sub-second chunk regenerations (no RTC interpolation between
chunks yet).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-12 11:41:57 +02:00
Pepijn 965d42825f review: skip-count fix, atomic writes, dedupe span reconstruction, role guards
**#1 Plan-update phase reports correct skip count.**
``_run_plan_update_phase`` only ran ``run_plan_updates`` for episodes
with at least one interjection but hardcoded ``episodes_skipped=0``.
The summary undercounted skipped episodes. Now returns
``len(records) - processed`` so processed + skipped == total.

**#2 ``run_hf_job.py`` installs ``openai``.**
The ``CMD`` block does ``pip install --no-deps lerobot[branch]`` then
explicitly lists transitive deps. ``openai`` was missing — and since
``VlmConfig.backend`` defaults to ``"openai"``, the job would have
``ImportError``'d when ``vlm_client._make_openai_client`` ran.

**#3 Dedupe subtask-span reconstruction.**
Module 1's ``_reconstruct_subtasks_from_rows`` (no ``and spans`` guard)
and Module 2's ``_read_subtask_spans`` (with the guard) had near-
identical logic. Promoted to ``reconstruct_subtask_spans`` in
``reader.py`` using the safer guarded form. Both modules now import
the single helper.

**#5 Atomic staging.py JSONL writes.**
Mirroring the parquet-writer fix from an earlier review round:
``EpisodeStaging.write`` now writes to a sibling ``.tmp`` and
``Path.replace`` atomically. A crash mid-write can no longer leave a
half-written JSONL that ``read()`` would then fail to parse.

**#6 Atomic ``info.json`` write.**
Same pattern in ``executor._ensure_annotation_metadata_in_info`` —
``info.json`` is load-bearing for dataset metadata, so partial writes
brick the dataset.

**#7 Writer's role-key guard.**
``_normalize_persistent_row`` and ``_normalize_event_row`` accessed
``row["role"]`` directly while every other field used ``.get()``.
Pre-validate ``"role" in row`` and raise a friendly ``ValueError``
naming the row, so a future module that accidentally drops ``role``
fails with a triagable message instead of a bare KeyError deep in the
writer.

**#8 Last subtask span's ``end`` extends to episode end.**
``reconstruct_subtask_spans`` (the new shared helper) takes an optional
``episode_end_t``. When provided, the final span's ``end`` is closed
to that timestamp instead of equalling its own ``start`` (zero
duration). Both Module 1's plan-update pass and Module 2's interjection
anchoring pass ``record.frame_timestamps[-1]``, so downstream "current
subtask at refresh_t" lookups no longer miss refreshes that land
inside the final span.

Sweep: 66 passed, 0 failed. Pre-commit clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 12:18:09 +02:00
Pepijn 1238a0cd47 test(annotate): unstale the two failing module tests
Both tests were stale relative to design changes that landed earlier on
this branch. Update the tests to match the current production contract.

**``test_module1_attaches_video_block_to_subtask_prompt``**

The test took ``captured[0]`` and asserted on its content blocks, but
Module 1 issues several sub-prompts and the rephrasings call (which is
text-only, no video block) usually lands first. Two fixes:

* The test's intent is "the subtask prompt carries the video block" —
  not "the first prompt carries it". Pick the call by content
  (``"atomic subtasks"`` keyword in the text block) so the test is
  resilient to future reordering of unrelated sub-prompts.
* Set ``n_task_rephrasings=0`` so the rephrasings call is skipped
  entirely — keeps the test focused on ``_generate_subtasks``.

**``test_module2_mid_episode_emits_paired_interjection_and_speech``**

Two issues both rooted in design changes on the branch:

1. ``InterjectionsAndSpeechModule._mid_episode_interjections`` now
   anchors interjections on subtask boundaries from Module 1's staging
   tree, bailing out with zero rows when no spans exist. The production
   executor runs Module 1 first; the test ran Module 2 in isolation.
   Reproduce the contract by seeding two ``style=subtask`` rows in the
   staging before calling Module 2 — gives it the single ``0 → 1``
   boundary it needs.
2. The test's stub responder used the marker ``"ONE realistic
   interruption"`` to match the interjection prompt, but that string is
   from a previous prompt version. The current
   ``module_2_interjection.txt`` says ``"Write ONE interjection..."`` —
   the old prompt asked for counterfactual interjections (e.g. "skip the
   wipe"), the new one anchors on the upcoming subtask. Marker updated
   to ``"Write ONE interjection"``; canned response wording aligned to
   the new design.

Sweep on the language stack: 66 passed, 0 failed (was 64 passed, 2
failed). Pre-commit clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 11:59:27 +02:00
Pepijn 53c7641885 review: fix dead-code bug, add thread safety, atomic writes, smaller cleanups
**Critical: video_for_episode was unreachable dead code.**
``video_for_episode`` was indented inside ``_decode_pyav_direct``, after
its ``return`` statement — Python parsed it as a nested function that
never executed. Module 1's ``_episode_video_block`` calls
``self.frame_provider.video_for_episode(record, target_count)`` on the
``use_video_url=False`` path, which would have AttributeError'd on any
real dataset. Tests passed only because they used ``_StubFrameProvider``
/ ``_NullProvider`` which have the method. Moved it to be a proper
method of ``VideoFrameProvider`` (right after ``frames_at``).

**Thread safety on VideoFrameProvider.**
The executor runs Module 1/2/3 phases under a ``ThreadPoolExecutor``, so
the per-instance ``_cache`` dict and the one-shot ``_warned_decode_fail``
flag were exposed to concurrent reads/writes. Added a ``threading.Lock``
field, wrapped cache reads/writes and the warn-flag check-and-set in
``with self._lock:``. Stub fixtures unaffected.

**episode_clip_path is now a method of VideoFrameProvider.**
Used to be a free function reaching into ``provider._meta.episodes`` and
``provider._meta.get_video_file_path`` from outside the class. As a
method it just uses ``self._meta``. The only caller (Module 1) updated;
no external callers.

**Atomic write in LanguageColumnsWriter.**
``pq.write_table(new_table, path)`` was overwriting the parquet shard
in place — a crash mid-write would corrupt the file. Now writes to a
sibling ``.tmp`` and ``Path.replace`` atomically.

**Smaller items:**
* ``executor.py`` docstring opened with "four phases" but listed six.
  Now says "six phases" to match.
* ``[annotations]`` extra in ``pyproject.toml`` now includes
  ``openai>=1.40,<2.0``. Default ``VlmConfig.backend`` is ``"openai"``,
  so without it ``_make_openai_client`` would ImportError on a fresh
  ``uv sync --extra annotations``.
* ``_snap_to_frame`` was duplicated identically in
  ``plan_subtasks_memory.py`` and ``interjections_and_speech.py``.
  Promoted to ``snap_to_frame`` in ``reader.py`` (next to
  ``EpisodeRecord``); both modules now import it. Backwards-compat alias
  not needed — no external callers.
* ``EpisodeRecord.frames_df()`` was re-reading the full parquet on every
  call. Now memoizes via a private dataclass field so repeat calls from
  different modules pay the cost once. Method signature unchanged.
* ``_extract_first_json_object`` had a redundant ``and not escape`` guard
  that was dead because the prior block already handled and reset
  ``escape``. Replaced with a comment explaining the invariant.

**Pre-existing lint cleanups surfaced once these files entered
pre-commit's scope:**
* dead local ``client = clients[0]`` in ``_make_openai_client`` (the
  real round-robin uses ``clients[rr_counter[...]]``).
* ``cmd = ... if "{port}" in cmd else f"...{port}"`` ternary collapse in
  ``_spawn_parallel_inference_servers``.
* ``seek_pts = 0 if stream.time_base is None else int(...)`` ternary
  collapse in ``_decode_pyav_direct``.
* ``# nosec B310`` on the localhost ``urllib.request.urlopen`` probe in
  ``_server_is_up`` — the URL is the user-configured local-server endpoint
  the CLI itself spawned, not arbitrary user input.

**Test added.**
``tests/annotations/test_frames.py`` pins the regression on
``VideoFrameProvider``: asserts ``video_for_episode`` and
``episode_clip_path`` are callable methods (not nested dead code or
free functions), and that the ``_lock`` field is a real
``threading.Lock``.

Sweep: 64 passed, 2 failed (same pre-existing module-impl bugs as
before this commit). Pre-commit clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 11:53:43 +02:00
Pepijn 088c8371df refactor(annotate): consolidate Module 1's prompt → VLM → JSON-extract pattern
Five Module 1 sub-prompts (`_derive_task_from_video`,
`_generate_task_rephrasings`, `_generate_subtasks`, `_generate_plan`,
`_generate_memory`) all repeated the same shape:

    result = self.vlm.generate_json([messages])[0]
    if isinstance(result, dict) and isinstance(result.get(<field>), <type>):
        ...

…each spelled with slightly different field names + post-processing.

Three small helpers replace it:

* `_vlm_field(messages, field)` — single VLM call, returns
  ``result[field]`` or ``None``. Centralizes the
  ``generate_json([m])[0]`` + ``isinstance(dict)`` dance.
* `_text_message(text)` — wraps a string in the canonical user-message
  shape every text-only prompt builds inline.
* `_video_message(record, prompt)` — combines the episode video block
  with a prompt; replaces the duplicated video-block construction
  inside `_generate_subtasks` (which previously inlined the same
  ``use_video_url``/``frames_per_second``/``max_video_frames`` branches
  that `_episode_video_block` already implements).

Net -35 LOC. Each call site now is 3-5 lines instead of 10-20. The
public method signatures are unchanged so tests don't move.

Drive-by: `_task_seems_bad` collapsed via SIM103 fix; `zip` in
`run_plan_updates` annotated `strict=True` per ruff B905.

Tests: same 2 pre-existing module-impl failures
(`test_module1_attaches_video_block_to_subtask_prompt`,
`test_module2_mid_episode_emits_paired_interjection_and_speech`) —
they were failing on `origin/feat/language-annotation-pipeline` before
this commit and continue to do so for the same reasons. 61/63 in the
language stack pass; pre-commit clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 11:29:45 +02:00
Pepijn 3a52a18b0e Merge branch 'feat/language-columns' into feat/language-annotation-pipeline
Resolve conflicts and pull in the latest PR 1 fixes.

Conflicts:
- pyproject.toml: PR 1 added `lerobot-rollout` and PR 2 added
  `lerobot-annotate` to the same `[project.scripts]` block. Kept both.
- uv.lock: dropped both sides and regenerated against the merged
  `pyproject.toml` (PR 2 dropped the `datatrove` dep when distribution
  moved to HF Jobs; PR 1's lock didn't have it).

Test follow-up:
- `tests/annotations/test_pipeline_recipe_render.py` — PR 1 deleted
  `src/lerobot/configs/recipes/pi05_hirobot.yaml` (review feedback:
  remove the canonical-recipe file; recipes are user-supplied). The
  cross-PR contract this test guards is "the recipe DSL renders
  non-empty messages from pipeline output", which doesn't depend on
  any specific YAML, so the test now builds an inline blend recipe
  with the same coverage. Passes.

Sweep: 82 passed, 2 failed (pre-existing module-impl bugs:
`test_module1_attaches_video_block_to_subtask_prompt`,
`test_module2_mid_episode_emits_paired_interjection_and_speech`).
The PR 1 carryover (`test_emitted_at_raises_on_ambiguous_per_camera_vqa`)
is now passing — the merge brought in PR 1's tightened `_select_one`
ambiguity check.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 11:13:11 +02:00
Pepijn dad2cf1178 refactor(annotate): delegate distribution to HF Jobs; drop SLURM/local switch
The executor previously claimed it would "optionally hand off" to
datatrove's LocalPipelineExecutor or SlurmPipelineExecutor — but it
already runs phases inline in every code path, and HF Jobs (see
``examples/annotation/run_hf_job.py``) is the actual distribution
strategy. Stop pretending we have an executor selector.

* `executor.py`: drop `select_executor_class`, the "kind" log line, and
  the references to LocalPipelineExecutor / SlurmPipelineExecutor.
  Module docstring now says distribution is delegated to HF Jobs.
* `config.py`: drop `auto_threshold`, `force_local`, `slurm_partition`,
  `slurm_gpus`, `slurm_time`, `workers`. `ExecutorConfig` keeps only
  `episode_parallelism`. While here, prune the longer "why" docstrings
  on every field down to the load-bearing bits — full story moves to
  `docs/source/annotation_pipeline.mdx`.
* `pyproject.toml`: drop `datatrove>=0.4.0,<2.0.0` from the
  `[annotations]` extra; the dep was only there for the (never used)
  cluster executors. Comment block notes the new HF-Jobs delegation.
* `reader.py`, `lerobot_annotate.py`: drop their own datatrove /
  flavor-namespace mentions.
* `docs/source/annotation_pipeline.mdx`:
  - remove the flavor-namespace / sidecar paragraph (out of scope —
    "multiple revisions = multiple copies" is dataset-level policy);
  - remove the "writer drops the legacy `subtask_index` column" note
    (already covered by PR 1's intentional-break call-out);
  - remove the chat-template + `apply_chat_template(messages, tools=...)`
    line (covered by Tools doc);
  - replace the "executor picks Local vs Slurm" paragraph with
    `--executor.episode_parallelism` and a pointer to HF Jobs;
  - rewrite the style→recipe section to talk about "recipes" generically
    instead of pinning a specific YAML;
  - add a "Running on Hugging Face Jobs" section pointing at
    `examples/annotation/run_hf_job.py`;
  - add a "Running locally" example matching the CLI's docstring
    (`uv run lerobot-annotate --root=... --vlm.model_id=...`);
  - extend the paper-inspirations list with Pi0.7 and Steerable VLA
    Policies (Zhao 2025) for Module 3.

Tests: same 3 pre-existing failures as before this commit (2 module
assertions still in flight; 1 carryover from PR 1). 41/44 pass.
Pre-commit clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 11:09:22 +02:00
Pepijn bce5387e04 Merge branch 'main' into feat/language-columns 2026-05-08 10:29:49 +02:00
Pepijn 85576acc29 docs(tools): drop follow-up-PR references
Reword the two callouts in `tools.mdx` to describe the runtime layer
in present tense ("not part of the catalog layer shipped today",
"those modules don't yet exist in the tree") instead of pointing at a
specific follow-up PR. Keeps the doc honest about what works now
without coupling it to a particular release order.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 20:29:42 +02:00
Pepijn e7e5fca5de review: emitted_at uses 0.1s tolerance; MessageTurn requires stream at construction
* **Float tolerance in `emitted_at` for persistent styles.** The
  ``_timestamp(row) == t`` exact-equality check silently missed any
  caller that derived ``t`` arithmetically (e.g. ``frame_idx / fps``)
  even though the parquet timestamp would only differ by ULPs. Added
  ``EMITTED_AT_TOLERANCE_S = 0.1`` and check ``abs(...) <= tolerance``
  instead, with a docstring explaining why exact equality wasn't
  enough and why 0.1 s is safe at typical 30–100 Hz control rates.
  Test asserts the new behavior at half-window (matches) and
  double-window (no match) using the constant so it stays in sync.

* **`MessageTurn.stream` is required at construction.** It was typed
  ``MessageStream | None = None`` so YAML could omit ``stream:`` and
  pass the dataclass invariant — but ``_validate_rendered`` rejected
  ``None`` streams later, surfacing the error at the first sample
  instead of at recipe load. Now ``__post_init__`` raises
  ``ValueError`` if ``stream`` is ``None``, with the list of valid
  streams in the message. The redundant late-stage check in
  ``_validate_rendered`` is replaced with a one-line comment that
  cites the upstream invariant. Test pins the new construction-time
  rejection.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 19:55:08 +02:00
Pepijn beb22afd81 review: dedupe regex, centralize column names, harden collate, more tests
* **#2 — dedupe `_PLACEHOLDER_RE`.** The same regex was compiled in
  `recipe.py` and `language_render.py`. Promote to module-level
  `PLACEHOLDER_RE` in `recipe.py` (its primary owner — declares
  template syntax) and import from `language_render.py`.
* **#3 — centralize language column names.** `io_utils.py` had
  hardcoded `{"language_persistent", "language_events"}` literals at
  two sites. Replace with `LANGUAGE_COLUMNS` import so a future column
  rename can't silently desync.
* **#4 — defensive collate preserved-keys.** `lerobot_collate_fn`
  silently filtered language fields from samples that didn't have
  them, which would hand downstream consumers a preserved list
  shorter than the tensor batch. Now: if any sample carries a key,
  every sample in the batch must carry it; otherwise raise a
  `ValueError` so the upstream rendering bug surfaces at the boundary.
* **#5 — `_scalar` rejects non-singleton lists.** Previously a zero-
  or multi-element list fell through and triggered confusing
  `float([])` errors downstream. Now raises `ValueError` with the
  actual length.
* **#6 — refactor `_extract_complementary_data`.** Replace 11 lines
  of `key = {... if ... else {}}` plus an 11-line splat dict with a
  single `_COMPLEMENTARY_KEYS` tuple iterated once.
* **#7 — document `EXTENDED_STYLES`.** Was an empty `set()` with no
  comment. Add a docstring explaining it's an intentional extension
  point: downstream modules append project-local styles before
  `column_for_style` is called.
* **#9 — `tools.mdx` notes the runtime layer is future work.** The
  page referenced `src/lerobot/tools/`, `registry.py`, and
  `get_tools(meta)` — none exist in this PR. Added a callout at the
  start of "How to add your own tool" plus a note on the
  implementations paragraph.
* **#10 — tests for YAML round-trip, malformed rows, blend
  validation.** `test_recipe.py` grew from 1 case to 12 covering:
  blend-or-messages exclusivity, target-turn requirement, blend
  emptiness, weight presence/positivity, nested-blend rejection,
  `from_dict` with nested blends, `from_yaml` / `load_recipe`
  agreement, top-level non-mapping rejection. Added a malformed-row
  test for `_normalize_rows` that asserts non-dict entries raise
  `TypeError`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 19:06:38 +02:00
Pepijn 33a4b4a5a0 feat(smolvla2): autonomous robot mode in lerobot-smolvla2-runtime
The runtime CLI was deliberately scoped to dry-run only: it
hard-coded ``robot_executor=None`` and printed a "real-robot
integration is a follow-up" warning even when ``--no_robot`` was
omitted. The runtime *engine* was already structured for real-robot
operation (separate ``LowLevelForward`` chunk-rate generation +
``DispatchAction`` ctrl-rate dispatch with a ``robot_executor``
hook); only the wiring was missing.

Add the wiring:

  * ``_load_policy_and_preprocessor`` now also returns the
    postprocessor (action denormaliser).
  * ``--robot.type`` / ``--robot.port`` / ``--robot.id`` /
    ``--robot.cameras`` (JSON) build a ``Robot`` via
    ``make_robot_from_config`` and connect it.
  * ``_build_robot_observation_provider`` reads
    ``robot.get_observation()`` each call, drops the language
    columns (runtime drives messages itself), and runs the policy's
    preprocessor (rename → batch → device → normalise).
  * ``_build_robot_action_executor`` postprocesses the policy's
    action tensor (denormalise), converts to the ``{joint: value}``
    dict via ``make_robot_action(action, ds_meta.features)``, and
    calls ``robot.send_action(...)``. Optional ``--max_action_norm``
    safety clip rejects ticks whose action L2 norm exceeds the
    threshold (kill-switch when bringing up a new robot).
  * ``_run_autonomous`` runs ``runtime.run()`` in a background
    thread (the policy must keep generating chunks at chunk_hz and
    dispatching at ctrl_hz regardless of stdin) and handles user
    interjections / VQA queries from the foreground stdin loop.
    Confirmation prompt before start (skip with ``--auto_start``);
    Ctrl+C stops the thread and disconnects the robot cleanly.
  * Autonomous mode requires ``--dataset.repo_id`` for action stats
    / feature shapes — pass the same dataset the policy was trained
    on. The bootstrap path that pulls canonical task / plan / memory
    runs in both REPL and autonomous modes so the model's first
    prompt matches training distribution.

Dry-run REPL behaviour is unchanged when ``--robot.type`` is not
passed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 18:30:56 +02:00
Pepijn d55b581ca1 fix(language): address review — tools accessor, motion docs, conditional collate
* **`meta.tools` actually reads `info.json["tools"]`.** `DatasetInfo`
  had no `tools` field, so `from_dict` silently dropped the key (it
  warned about unknown fields then discarded them) and the property
  always returned `DEFAULT_TOOLS`. Added `tools: list[dict] | None`
  to the dataclass; `to_dict()` drops it when unset so existing
  datasets keep a clean `info.json`. Fixed the accessor to read
  `self.info.tools` (the previous `.get(...)` would have raised
  AttributeError on the dataclass anyway). Added regression tests:
  fallback when absent, round-trip from disk, and round-trip
  through `DatasetInfo.from_dict` / `to_dict`.

* **`motion` is not view-dependent — fix the docs.** The mdx claimed
  rows of style `motion` must carry `camera`, but `VIEW_DEPENDENT_STYLES
  = {"vqa", "trace"}` and the validator agrees: motion primitives are
  joint/Cartesian-frame, not pixel-space. Updated both call-out
  paragraphs in `language_and_recipes.mdx`.

* **Conditional `collate_fn` swap.** Added `meta.has_language_columns`
  and gate the `lerobot_collate_fn` swap in `lerobot_train.py` on it,
  so non-language datasets keep PyTorch's `default_collate`. Also
  added a pass-through test in `test_collate.py` that asserts on a
  plain tensor batch the custom collate matches `default_collate`
  key-for-key, plus a test for the `None`-sample drop path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 14:51:06 +02:00
Pepijn 24d2ffe3c6 fix(language): keep base install green — drop processor re-export, gate dataset-extra tests
`lerobot.processor` re-exported `RenderMessagesStep` at the package
level, so importing anything from `lerobot.processor` pulled in
`lerobot.datasets.language` → `lerobot.datasets/__init__.py` →
`require_package("datasets")`, which fails in the Tier 1 base install
that intentionally omits the `[dataset]` extra. The chain bricked
collection for unrelated suites (`tests/policies/pi0_pi05/...`,
`tests/envs/...`, etc.).

* Stop re-exporting `RenderMessagesStep` from `lerobot.processor`. The
  only consumer (the test) already imports from the submodule.
  Document the deliberate omission in the module docstring.
* Add `pytest.importorskip("datasets", ...)` (and `pandas` where
  needed) at the top of the four PR-added tests that exercise the
  language stack:
  - tests/datasets/test_language.py
  - tests/datasets/test_language_render.py
  - tests/processor/test_render_messages_processor.py
  - tests/utils/test_collate.py

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 14:12:54 +02:00
Pepijn 789f29aa56 chore: fix CI — collapse short ValueError to one line, refresh uv.lock
* `ruff format` on CI (newer version) wants the short `camera=None`
  ValueError on a single line.
* `uv.lock` was stale relative to `pyproject.toml`'s `datasets>=4.7.0`
  pin (and picked up upstream `s390x` marker fixes for cuda packages).
  CI runs `uv sync --locked` which rejected the divergence.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 14:05:42 +02:00
Pepijn a356b12c41 fix(language): always raise on ambiguous resolver matches
`_select_one` previously skipped its ambiguity check whenever any of
`role`/`tool_name`/`camera` was set, on the assumption that the caller
had already pinned down a unique row. That left a real ambiguity hole
for VQA: with two cameras emitting `(vqa, assistant)` at the same
frame, `emitted_at(..., role="assistant")` silently picked the first
sorted row instead of telling the recipe to add `camera=...`. The
existing `test_emitted_at_raises_on_ambiguous_per_camera_vqa` test
already encoded the desired behavior.

Tighten the check: any time `len(rows) > 1` we now raise with the
selectors echoed back, so users see exactly which fields they passed
and that more is needed to disambiguate.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 14:00:45 +02:00
Pepijn e8327b8e62 refactor(language): unify resolver dispatch and prune redundant test scaffolding
* Drop the unused `events` kwarg from `active_at`/`nth_prev`/`nth_next`;
  only `emitted_at` actually consults events. The dispatcher in
  `_resolve_spec` now passes events conditionally.
* Replace the dual `_persistent_sort_key`/`_event_sort_key` pair with a
  single `_row_sort_key` and drop the `sort_key` parameter from
  `_select_one`. Event rows lack `timestamp` (it is implicit in the
  frame) and now default to `0.0` for sort purposes — the
  `(style, role)` tiebreaker is unchanged.
* Inline `_select_latest` into `active_at` (its only caller).
* Collapse `emitted_at`'s dual-branch into one `_select_one` call.
* Tighten `_validate_persistent_resolver` to a single
  `column_for_style(style) != LANGUAGE_PERSISTENT` check.
* Parameterize `test_per_camera_blend_renders_both_views` over the two
  cameras and factor the sub-recipe builder into `_vqa_subrecipe` so
  the test no longer hand-rolls two near-identical recipe blocks.

Net -98 LOC; behavior, public resolver names, and test expectations
unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 13:15:45 +02:00
Pepijn c450298147 Apply ruff and prettier formatting after merge
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 12:10:41 +02:00
Pepijn 5c30b14929 Merge remote-tracking branch 'origin/main' into feat/language-columns 2026-05-06 12:09:13 +02:00
Pepijn a764c3e1d6 fix(datasets,annotate): tag pushed dataset + clean revision error
Two bugs combining to make the brand-new ``_tool3`` dataset
unloadable:

1. ``lerobot_annotate.py:_push_to_hub`` uploads the annotated
   dataset folder but never creates a codebase-version tag, so
   ``api/datasets/<repo>/refs`` returns ``"tags": []``. Then
   ``LeRobotDatasetMetadata`` → ``get_safe_version`` →
   ``get_repo_versions`` returns empty and the loader raises
   ``RevisionNotFoundError``.

2. ``RevisionNotFoundError`` itself was unconstructible: its
   ``HfHubHTTPError.__init__`` indexes ``response.headers``
   unconditionally on current ``huggingface_hub`` versions, so
   constructing it without a real ``Response`` blew up with
   ``AttributeError: 'NoneType' object has no attribute 'headers'``,
   masking the real "no tag" message.

Fix #1: after upload, read ``meta/info.json["codebase_version"]`` and
``HfApi.create_tag(..., tag=<v3.x>, repo_type='dataset',
exist_ok=True)`` so the dataset is loadable straight from the Hub on
the next ``LeRobotDataset(repo_id)`` call. Falls back to the in-tree
``CODEBASE_VERSION`` if info.json is missing/malformed; on tag
creation failure, prints the manual one-liner the user needs.

Fix #2: stop trying to instantiate ``RevisionNotFoundError`` (which
inherits HfHubHTTPError) for what is really a config issue, not an
HTTP failure. Raise plain ``RuntimeError`` with the same message —
the caller actually sees what's wrong instead of an upstream
attribute error.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 18:23:18 +02:00
Pepijn b416f287f2 fix(datasets): raise readable error when repo has no version tags
``RevisionNotFoundError`` inherits from
``huggingface_hub.HfHubHTTPError`` which made ``response`` a required
keyword-only argument on recent versions. Constructing it with just a
message string blew up with
``TypeError: HfHubHTTPError.__init__() missing 1 required keyword-only
argument: 'response'`` instead of surfacing the actual problem (the
dataset/checkpoint repo doesn't exist on the Hub yet).

Pass ``response=None`` explicitly. Fall back to the bare-message form
for older ``huggingface_hub`` versions that don't accept the kwarg.
Also clarify the message to call out the most common cause: typing a
hub repo id that hasn't been pushed yet (instead of just "needs a
version tag").

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 18:12:40 +02:00
Pepijn aa749d4947 chore(annotate): throttle Module 3 + executor parallelism to fix vLLM stall
Last bump combined ``module_3.K=3`` with ``vqa_emission_hz=2.0`` and
``executor.episode_parallelism=32``. With 2 cameras per dataset that
produced ~12× the original VQA call volume, all submitted concurrently.
Module 3 latency went from ~30s/phase to ~490s per episode, vLLM's
KV cache pegged at 94% with 800+ in-flight requests, and the
multimodal cache corrupted with ``AssertionError: Expected a cached
item for mm_hash='...'`` (a known vLLM bug under image-heavy
concurrency). Module 1 and 2 ran fine; Module 3 was the bottleneck.

Pull back the multipliers to land in a sustainable spot:

  * module_3.K: 3 (kept) — three diverse questions per emission,
    where the diversity actually helps the LM head.
  * module_3.vqa_emission_hz: 2.0 → 1.0 — back to the original
    emission rate. Net VQA volume is now ~3× original (K alone) on
    a single camera, ~6× across both cameras — manageable.
  * module_2.max_interjections_per_episode: 9 → 6 — still 2× the
    default, fewer than the prior 3× to keep total request volume
    in check.
  * vlm.client_concurrency: 256 → 128 — gives vLLM headroom on the
    multimodal request path so the mm_cache doesn't desync.
  * executor.episode_parallelism: 32 → 16 — half the episodes
    in flight at once, so peak vLLM load is ~half.

n_task_rephrasings stays at 30 (text-only, doesn't load the image
path) and vlm.temperature stays at 0.7. The diversity gains are
preserved; only the throughput knobs come down.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 15:07:18 +02:00
Pepijn 1394a6ab5d chore(annotate): bump diversity knobs ~3x to fight memorisation
Following Pi0.7 §V (prompt expansion / diverse context conditioning),
push more atom variants per episode and higher VLM sampling
temperature so the training distribution has enough wording diversity
that the LM head is forced to use its parameters rather than memorise
specific (prompt, target) pairs.

Changes vs prior annotation pass:

  * vlm.temperature: 0.2 (default) → 0.7 — every Module-1/2/3 call
    now produces diverse phrasings; same prompt yields different
    completions across emissions.
  * module_1.n_task_rephrasings: 10 → 30 — three times as many
    ``task_aug`` rows in language_persistent. ``${task}`` already
    rotates through them deterministically per sample_idx (see
    ``_resolve_task`` in language_render.py).
  * module_2.max_interjections_per_episode: 3 (default) → 9 — more
    ``user_interjection_response`` training samples + more plan
    refresh events.
  * module_3.K: 1 → 3 — three VQA pairs per emission tick instead of
    one. Combined with the hz bump below, ~6× more VQA samples.
  * module_3.vqa_emission_hz: 1.0 → 2.0 — double the VQA emission
    rate within each subtask span.

Pushes to a new hub repo (``_tool3``) so the working ``_tool2``
dataset stays intact for comparison. ``${task}`` already wired to
rotate through ``task_aug`` rows, so no renderer change needed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 14:32:05 +02:00
Pepijn db9118f16f fix(smolvla2): reject gibberish high-level generations
Memorised models can collapse to dominant-mode outputs (the
JSON-token salad ``":":":":...`` from VQA training) when the prompt
drifts even slightly from training distribution. Without a guard,
that gibberish lands in ``current_subtask`` / ``current_plan`` /
``current_memory``, which feeds the next tick's prompt and cascades
into worse outputs. The user observed exactly this: a clean run
followed by a tick that wrote ``" " "`` into plan and memory, then
slow recovery several ticks later.

Add ``_looks_like_gibberish`` heuristic (alpha density, repeating
chars, JSON-prefix sniff) and apply it before mutating state in
``HighLevelSubtaskFwd`` / ``MemoryUpdateFwd`` / ``UserInterjectionFwd``.
Bad generations are logged inline (``[info] subtask gen rejected
(gibberish): "":":":..."``) so the user can see what was dropped, but
the state stays at its last-known-good value (typically the dataset
bootstrap) instead of being polluted.

VQA path is intentionally exempt — its training targets *are*
JSON-shaped, so the heuristic would false-positive on them.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 14:07:25 +02:00
Pepijn 7a945d7bdc fix(smolvla2): bootstrap canonical task + plan/memory from dataset
The user-typed task and the dataset's canonical task differ in
wording (capitalisation, ``green box`` vs ``green bin``, etc.). With
``text_loss`` driven down to ~6e-6 across 78 epochs the model is
memorised on the *exact* rendered training prompts: any wording drift
puts the prompt out of distribution and the model collapses to its
dominant training mode (VQA JSON output).

When ``--dataset.repo_id`` is set, automatically:
  * read the canonical task string from the chosen episode (and use
    it as ``--task`` when the user didn't pass one);
  * pull the active ``plan`` / ``memory`` / ``subtask`` rows from the
    persistent slice (latest row whose timestamp ≤ start frame's
    timestamp — same semantics as the renderer's ``active_at``) and
    seed them into the runtime state.

The first prompt the runtime builds at REPL start now mirrors what
the recipe rendered during training (task + active plan + active
memory + optional current subtask). The user can still override any
of these by typing.

Memorisation itself is upstream (training mix collapsed to too few
unique high-level targets); this commit only fixes the inference-side
prompt mismatch that was making the memorisation surface as gibberish.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 14:00:36 +02:00
Pepijn a47e535b02 fix(smolvla2): per-recipe inference prompts to match training shape
The four high-level steps shared one generic
``_control_context_messages`` that jammed task + plan + memory +
completed_subtask into a single user message. The recipes in
``smolvla2_hirobot.yaml`` each have a *specific* multi-message layout
(``memory_update``: ``user(task) → assistant(prev memory) →
user(completed subtask)``; ``high_level_subtask``: ``user(task+plan+
memory) → user(current subtask)``; ``user_interjection_response``:
``user(task) → assistant(prev plan) → user(interjection)``). After
``apply_chat_template`` those layouts produce different prompts than
the runtime's flattened single-user-turn version, and the model fell
back to its dominant training mode (VQA JSON output) — generating
``":":":":":":...`` repetition.

Add four per-recipe prompt builders (``_msgs_for_subtask``,
``_msgs_for_memory``, ``_msgs_for_interjection``, ``_msgs_for_vqa``),
each mirroring its sub-recipe's exact message structure including
the ``if_present`` skips. Wire each high-level step to its matching
builder. Inference prompts now line up with what the model saw in
training, so generation should produce coherent text instead of
repeated tokens.

Generic ``_control_context_messages`` is kept (still used by tests
and the no-recipe fallback path).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 13:47:22 +02:00
Pepijn 6d9b431b54 fix(smolvla2): match training's text-loss forward in select_message
Previous rewrite drove generation through ``vlm.generate()`` (the
standard SmolVLM path), which ignores SmolVLA's custom ``embed_prefix``
that interleaves images + lang + state. Result: the model received a
prompt format it had never been trained on at inference and emitted
JSON-fragment gibberish (``" " " ,",","`` ``cube lift {"...``).

Revert to the cumulative-buffer AR loop driven through
``vlm_with_expert.forward`` — the *same* forward call ``_compute_text_loss``
makes during training (``inputs_embeds=[prefix_embs, None],
use_cache=False, fill_kv_cache=True``). With ``fill_kv_cache=True``,
every layer routes through ``forward_attn_layer``, which gracefully
skips ``None`` expert inputs (``if hidden_states is None or layer is
None: continue``); cross-attention layers — which would otherwise hard-
require a non-None expert input — are bypassed entirely.

Inference now sees the same prefix structure as training: images +
lang + state, with new tokens appended to the lang region. The text
distribution matches what the model was trained to produce.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 13:42:15 +02:00
Pepijn 347e706326 fix(smolvla2): drop pixel_values from select_message generate path
SmolVLA's image preprocessor sizes frames to whatever the action
expert was trained on, but SmolVLM's standard vision tower expects
its own default tile grid (e.g. 384/14 → 27×27 patches). The
mismatch surfaces deep in the post-vision reshape as
``RuntimeError: shape '[2, 34, 34, 768]' is invalid for input of
size 1843200`` — the model has 1200 patches but expects 34×34=1156.

Drop ``pixel_values`` from ``vlm.generate(...)`` so SmolVLM runs as
a text-only LM at REPL time. The high-level branches (subtask /
plan / memory) are dominated by their text context anyway, so this
is acceptable for dry-run inference. VQA loses its image grounding
— that will be marked as expected for the dry-run path until a
follow-up either re-processes images through SmolVLM's own
``ImageProcessor`` to match its tile grid, or gives
``vlm_with_expert`` a real AR text decode mode that handles state
and image embeddings the way training does.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 13:36:53 +02:00
Pepijn fa8ae1e89b fix(smolvla2): drive select_message through SmolVLM.generate
The hand-rolled AR loop in ``select_message`` was fighting the
underlying ``vlm_with_expert.forward`` design, which assumes the
"prefix-once + suffix-always-via-expert" pattern that ``denoise_step``
uses for action chunks. Cross-attn layers (every other layer with
``attention_mode='cross_attn'`` + ``self_attn_every_n_layers=2``)
hard-require an expert input on every call: passing
``inputs_embeds=[current_embs, None]`` crashed at
``expert_layer.input_layernorm(None)`` with ``'NoneType' object has
no attribute 'dtype'``. Earlier KV-cache attempts ran into the
matching ``[15, 139] vs [15, 1]`` shape mismatch because the cache
gets *overwritten*, not appended, on each ``fill_kv_cache=True`` call
— there's just no AR-text-decode mode in this forward.

Stop fighting it: drive AR text generation through the underlying
SmolVLM via ``vlm.generate(input_ids=..., attention_mask=...,
pixel_values=...)``. KV caching, sampling/greedy, EOS handling all
come from HF's standard implementation. Trade-off: ``state`` drops
out of the prefix at inference (no slot for it on the standard
SmolVLM path), so high-level generations may drift from training
distribution slightly. That's acceptable for the dry-run REPL — the
high-level branches (subtask / plan / memory / vqa) are mostly
vision+language conditioned anyway, and the action expert (where
state actually matters) goes through the unchanged ``select_action``
path.

Image features the runtime merged in (``observation.images.*``) are
stacked into the ``[B, num_images, C, H, W]`` shape SmolVLM expects.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 12:39:34 +02:00
Pepijn 3ff6c6860e fix(smolvla2): rewrite select_message decode loop without KV cache
SmolVLA's ``vlm_with_expert.forward`` doesn't actually support
incremental KV cache growth — its only ``fill_kv_cache=True`` mode
*overwrites* the cache with the latest call's key/value states, and
its only ``fill_kv_cache=False`` mode concatenates ``cache + new``
into a local ``key_states`` for one matmul without ever updating the
cache itself. The original ``select_message`` decode loop tried to
use ``fill_kv_cache=True`` per step, which clobbered the cache to
1 token after the first decode and threw
``Expected size for first two dimensions of batch2 tensor to be:
[15, 139] but got: [15, 1]`` — the attention mask still expected
139 keys but the cached + new key_states only had 1.

Match the pattern ``denoise_step`` already uses successfully:
maintain a cumulative ``(embs, pad, att)`` buffer that starts as the
prefix and grows by one bool/embedding row per step. Each step
forwards the *full* sequence with ``use_cache=False,
fill_kv_cache=False, past_key_values=None`` so the matmul shapes
always line up. Generated-token rows are tagged ``pad=1, att=1``
which makes them fully causal among themselves while still able to
attend back to the entire prefix (per ``make_att_2d_masks``
semantics: a token can attend to any earlier token whose cumulative
``att`` count is ≤ its own).

Image encoding is still done once via the initial ``embed_prefix``
call — the expensive part doesn't repeat. The remaining cost is
O(n²) text-only transformer forwards, which is fine for the dry-run
REPL's 50–100 token responses.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 12:15:28 +02:00
Pepijn fd89efb545 fix(smolvla2): 3D attention mask in select_message decode loop
SmolVLA's ``eager_attention_forward`` does
``masked = torch.where(attention_mask[:, None, :, :], ...)``, which
requires a 3D ``[B, query_len, key_len]`` bool tensor so the
broadcast to 4D works. ``select_message``'s prefix forward got this
right (passes ``prefix_2d`` from ``make_att_2d_masks``), but the
KV-cache decoding loop built ``new_attn = torch.ones((bsize,
cur_pos + 1))`` — 2D — and the very first decode step blew up with
``IndexError: too many indices for tensor of dimension 2``.

During KV-cache decoding ``query_len = 1`` and
``key_len = cur_pos + 1`` (prefix + every token already generated),
so the right shape is ``[B, 1, cur_pos + 1]``. Match the layout
SmolVLA's working ``denoise_step`` uses for the equivalent
``prefix_pad_2d_masks`` build.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 12:08:52 +02:00
Pepijn 2776b57c9e fix(smolvla2): bool attention mask + clean Claude-Code-style REPL
Two issues that combined to make the REPL unusable:

1. ``BatchEncoding.attention_mask`` is a ``Long`` tensor, but SmolVLA's
   ``eager_attention_forward`` does
   ``torch.where(attention_mask[..., None, :, :], ...)`` which
   requires a *bool* condition. Every forward raised ``where expected
   condition to be a boolean tensor, but got a tensor with dtype Long``
   and the diagnostic surfaced it cleanly in the REPL — but generation
   produced nothing useful. Cast to ``bool`` in ``_build_text_batch``
   so the prefix forward goes through.

2. The interactive REPL used ``rich.live.Live`` panels stacked on top
   of ``logging.basicConfig(level=DEBUG)`` HTTP request lines from
   ``httpcore`` / ``httpx`` / ``huggingface_hub``. The two rendering
   loops fought each other in the user's terminal and the output was
   illegible: hundreds of debug lines interleaved with re-rendered
   panels.

   Replace ``Live`` with a simple block redraw — clear screen, print
   the state block, print any robot log lines, then a single ``> ``
   prompt. State changes are visible above the prompt, the way Claude
   Code's REPL renders. No flicker, no re-render races.

   ``_silence_noisy_loggers`` drops the chatty third-party HTTP /
   download / model-init loggers to WARNING. ``-v`` still enables
   DEBUG on the lerobot loggers; if the user needs the HTTP traces,
   they can flip those individually.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 12:03:47 +02:00
Pepijn 0fb5f04965 fix(smolvla2): handle BatchEncoding return from apply_chat_template
``tokenizer.apply_chat_template(..., tokenize=True, return_tensors='pt')``
on newer transformers returns a ``BatchEncoding`` (dict-like) rather
than a raw ``Tensor`` — particularly when the underlying call routes
through a processor. ``_build_text_batch`` only handled the ``Tensor``
and ``list`` shapes, so the encoding object reached SmolVLA's
``embed_language_tokens`` and ``F.embedding`` blew up with
``argument 'indices' must be Tensor, not BatchEncoding`` on every
high-level forward.

Normalise the return:
  * ``BatchEncoding`` / ``dict`` → take ``input_ids`` (and the encoder's
    ``attention_mask`` when present, since ``pad_token_id`` can be
    ``None`` for SmolVLM and the fall-back ``ids != pad_token_id``
    breaks then),
  * ``list[int]`` / ``list[list[int]]`` → wrap in a long tensor,
  * ``Tensor`` → keep as-is.

After unwrapping, ensure shape ``(1, seq)`` and that ``attention_mask``
is a tensor on the same device as ``ids``.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 11:59:57 +02:00
Pepijn 7296ac97af fix(smolvla2): make silent generation failures visible in REPL
Two failure modes were combining to make the runtime "look dead":

1. ``_build_text_batch`` produced lang tokens via
   ``apply_chat_template(return_tensors='pt')`` on CPU, but the policy
   sits on the configured device (mps / cuda). The first prefix-embed
   inside ``select_message`` then raised a device-mismatch on every
   call. The bare ``except Exception`` in ``_generate_with_policy``
   swallowed it at debug level — no logs, no chat output, no visible
   sign anything had run.

2. Even when generation succeeded but returned an empty string
   (greedy EOS, unhappy chat template, etc.), the high-level steps
   silently no-op'd, so users saw nothing.

Move tokens to ``policy.config.device`` in ``_build_text_batch`` so
the prefix forward succeeds in the common case. Bump the swallowing
log level to ``warning`` (with optional traceback under ``-v``), and
when ``state`` is given route the same diagnostic into the REPL log
via ``push_log`` so the user sees ``[warn] subtask gen failed: ...``
inline. Also push an ``[info] ... produced no text this tick`` line
when generation runs but yields nothing, so empty completions are
distinguishable from "step never ran". Apply the same surface to
``LowLevelForward.select_action`` failures.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 11:47:34 +02:00
Pepijn 9cbbcfb6a2 fix(smolvla2): tokenize lang prompt inline before select_action
LowLevelForward was handing the observation provider's output straight
to ``policy.select_action``, but SmolVLA's ``_get_action_chunk``
indexes ``batch[OBS_LANGUAGE_TOKENS]`` and crashes with ``KeyError:
'observation.language.tokens'`` when the key isn't there. Our provider
deliberately strips the dataset's language columns (the runtime drives
messages itself), so nothing else was producing those tokens — the
chunk path crashed on the very first tick after task was set.

Build a low-level prompt from current runtime state inline (task /
plan / memory as the user turn, current subtask appended as a
continuation assistant turn when known), tokenize it with the same
helper the high-level steps use, and merge ``lang_tokens`` /
``lang_masks`` into the observation before the call. Skip the step
when no task is set yet, and swallow ``select_action`` exceptions at
debug level so a missing observation feature doesn't kill the REPL.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 11:40:18 +02:00
Pepijn fea41b29f5 fix(datasets): probe parquet for language columns before strict cast
``_load_hf_dataset`` was building the strict cast schema only from
``meta/info.json["features"]``. Datasets annotated by
``lerobot-annotate`` but still tagged at the older codebase version
(no ``language_persistent`` / ``language_events`` entry in
``info.json``) carry both columns in the parquet itself but not in the
features dict, so ``Dataset.from_parquet`` blew up with
``CastError: column names don't match`` when trying to project a
9-column parquet onto a 7-column schema.

Probe one parquet shard's actual schema; if either language column is
present in the parquet but missing from ``features``, graft it on
using PR 1's ``language_persistent_column_feature`` /
``language_events_column_feature`` helpers. No-op when neither column
is present (fully backwards-compatible with v3.0 datasets), no-op when
both are already registered (fully forwards-compatible with future
v3.1 ``info.json`` writes).

This unblocks dry-run inference on PR 2-annotated datasets that
weren't re-tagged to v3.1 — including the ones in the field today.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 11:31:19 +02:00
Pepijn 7b4d281ef5 fix(smolvla2): build preprocessor fresh, don't round-trip the recipe
``PolicyProcessorPipeline.from_pretrained`` reconstructs each saved
step by passing the persisted JSON config back to ``__init__``, but
``RenderMessagesStep.recipe`` (a ``TrainingRecipe``) doesn't survive
the JSON round-trip — the saved entry is ``{}`` and the reconstructor
crashes with ``missing 1 required argument: 'recipe'``.

Bypass the round-trip in the runtime CLI by passing
``pretrained_path=None`` to ``make_pre_post_processors``. That re-runs
``make_smolvla2_pre_post_processors``, which reloads the recipe YAML
referenced by ``cfg.recipe_path`` and wires it back into the step
correctly. ``NormalizerProcessorStep`` still gets stats from
``ds_meta.stats`` so normalization matches training.

Proper fix is to make ``RenderMessagesStep`` serializable (e.g. by
persisting the recipe path / contents); this commit keeps it scoped to
the runtime path so dry-run testing isn't blocked.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 11:27:12 +02:00
Pepijn 29bb8bb20e fix(tools): unblock pocket-tts resolution (>=1.0.0,<3.0.0)
The previous bound `>=0.1.0,<1.0.0` matched zero published versions —
pocket-tts went straight to 1.0.0 on PyPI, with 0.x never released.
That made `uv sync --extra tools` (and any sync that pulls the `dev` /
`all` superset) fail with "requirements are unsatisfiable" on every
Python version uv tried, including 3.12.

Bump to `>=1.0.0,<3.0.0` so 1.x and 2.x are reachable. SayTool only
touches `TTSModel.load_model()`, `get_state_for_audio_prompt`,
`generate_audio`, and `sample_rate` — small enough surface that 1.x
and 2.x should both work; tighten if a real API break shows up.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 11:15:20 +02:00
Pepijn 3fe686ce9f feat(smolvla2): runtime accepts Hub IDs + dataset-driven dry-run
The runtime CLI's loader was broken — it imported a `make_policy_from_path`
that doesn't exist in `lerobot.policies.factory` — and the high-level text
steps generated plan / subtask / memory / VQA from a text-only batch with
no images or state, so dry-runs drifted from the training distribution.

Switch to the standard `PreTrainedConfig.from_pretrained` +
`make_policy(cfg, ds_meta=...)` flow so `--policy.path` accepts both local
directories and Hub repo ids, and add a `--dataset.repo_id` path that walks
a chosen episode and feeds preprocessed observations into every forward
pass — including the four high-level steps (`HighLevelSubtaskFwd`,
`MemoryUpdateFwd`, `UserInterjectionFwd`, `AskVQAFwd`). Frames are routed
through the saved preprocessor pipeline with `language_persistent` /
`language_events` stripped so the recipe-render step stays a no-op (the
runtime supplies its own messages from current state).

Also wires the rich-based two-zone REPL layout (`ui.py`) that the script
was already importing.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 11:09:19 +02:00
pepijn a1b8134ef1 fix(smolvla2): train on rendered language batches
Keep annotated language columns through collation, render batched recipe samples, and make SmolVLA2 text loss robust enough for distributed training on the steerable dataset.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-05 08:55:56 +00:00
pepijn 8fa8323c91 fix(annotate): sync language metadata after parquet rewrite
Ensure annotated datasets advertise language columns in meta/info.json so non-streaming dataset loads cast against the rewritten parquet schema.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-04 15:17:15 +00:00
Pepijn 5f7c6ba61d feat(annotate): compact steerable annotation prompts
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-04 15:57:04 +02:00
Pepijn 223cc8a9e2 feat(smolvla2): inference runtime — select_message + multi-rate REPL
Closes the loop on PR 3: SmolVLA2 can now be queried interactively at
inference, dispatching the same five sub-recipe shapes it was trained
on (action chunks, subtask gen, memory updates, plan/speech on
interjection, VQA on questions).

Modeling fixes + additions
--------------------------

- ``_compute_text_loss``: standard next-token CE shift was missing
  (logits at position t were CE'd against the label at t — identity-
  mapped, learning nothing). Adds ``logits[:, :-1]`` /
  ``labels[:, 1:]`` shift to match HuggingFace ``LlamaForCausalLM``.

- New ``select_message`` on ``SmolVLA2Policy``: AR text generation
  with KV caching, mirroring SmolVLA's ``select_action`` pattern.
  Single prefix forward fills the cache, then per-token forwards
  reuse it. Greedy + top-p nucleus sampling. Returns the decoded
  string with the prompt stripped.

Runtime package — ``src/lerobot/policies/smolvla2/inference/``
-------------------------------------------------------------

- ``triggers.py`` — ``Trigger`` Protocol + ``HzTrigger`` /
  ``EventTrigger`` + ``TickClock``. The whole runtime ticks at
  ``max_rate_hz=50`` and each step gates itself off its own
  cadence.

- ``runtime_state.py`` — runtime state dict factory plus tiny
  helpers (``take_event``, ``set_if_changed``, ``push_log``).
  Stable keys are documented at the top of the module.

- ``steps.py`` — :class:`InferenceStep` base + concrete steps:
  ``LowLevelForward`` / ``DispatchAction`` (action path),
  ``HighLevelSubtaskFwd`` / ``MemoryUpdateFwd`` /
  ``UserInterjectionFwd`` / ``AskVQAFwd`` (text paths),
  ``DispatchToolCalls`` (tool registry → ``Tool.call``). Each
  text step builds a chat-template prompt from current
  ``RuntimeState`` (task / plan / memory / subtask) matching
  what ``smolvla2_hirobot.yaml`` renders during training.
  Includes a tiny ``<say>...</say>`` parser for the
  ``user_interjection_response`` branch's combined plan + speech
  output.

- ``runtime.py`` — :class:`SmolVLA2Runtime` composes the pipeline,
  drives ticks via ``TickClock``, polls a user-supplied
  ``event_collector`` per tick, and prints state-change log lines.

- ``repl.py`` — :class:`StdinReader` non-blocking line reader
  with simple intent classification: ``stop`` / ``quit`` /
  ``exit`` → terminate; ``?`` suffix → ``user_vqa_query`` event;
  first line → set task; other lines → ``user_interjection``.

CLI
---

- ``src/lerobot/scripts/lerobot_smolvla2_runtime.py``: console
  script ``lerobot-smolvla2-runtime`` that loads a checkpoint,
  optionally instantiates ``SayTool`` (pocket-tts), wires up
  ``SmolVLA2Runtime`` + ``StdinReader``, and runs.

  Real-robot wiring (observation_provider / robot_executor) is
  intentionally left as a follow-up — v1 is dry-run / language-
  only so the REPL works without robot hardware.

  Registered in ``pyproject.toml`` ``[project.scripts]``.

Known follow-ups
----------------

- Real-robot integration: today ``LowLevelForward`` only fires when
  an observation_provider is wired. The CLI prints a warning if
  ``--no_robot`` is omitted.
- ``select_message`` runs an extra prefix forward; could share with
  the action path's prefix when both are needed in the same tick.
- Tests: no end-to-end runtime test yet (would need a tiny SmolVLM
  fixture). The components compile and the public surface is
  exercised by the CLI's argument-parsing path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 22:04:00 +02:00
Pepijn af6d8ebd5b feat(smolvla2): dual-head forward — flow loss + lm_head text loss
The third and final commit of PR 3's SmolVLA2 work. Wires the actual
training signal through:

* ``predict_actions[i] = True``  → sample i contributes to flow loss
* ``text_labels[i, t] != -100``  → token t of sample i contributes to
                                    LM-head cross-entropy

Both routing knobs come from ``SmolVLA2ChatTokenizerStep`` (previous
commit on this branch), which builds them from the recipe's
``message_streams`` / ``target_message_indices``. The per-sample
``predict_actions`` mask preserves the Pi0.5 convention from the
plan's Section I.7: "True iff any low_level target exists".

Implementation:

- ``forward`` reads ``text_labels`` and ``predict_actions`` from the
  batch. When neither is present (vanilla SmolVLA usage with no
  recipe), delegates to ``SmolVLAPolicy.forward`` so unannotated
  datasets keep training as before — full backward compatibility.
- ``flow_loss``: super().forward(reduction="none") returns the
  per-sample (B,) flow loss; we mask non-action samples with the
  ``predict_actions`` bool and renormalize by the count of action
  samples. ``flow_loss_weight = 0`` in the config disables this
  branch entirely (text-only training).
- ``text_loss``: a prefix-only forward through the VLM (no action
  expert / suffix), slicing the lang-token range out of the
  resulting hidden states (``embed_prefix`` orders the prefix as
  ``[image_blocks..., lang, state]`` so the slice is unambiguous).
  Apply ``vlm.lm_head`` to those hidden states, cross-entropy with
  ``text_labels`` (ignore_index=-100). ``text_loss_weight = 0``
  disables this branch (reverts to flow-only behaviour, matching
  SmolVLA exactly).
- The two losses are summed with the config-supplied weights.

Mixed-stream samples (one batch containing both action targets and
text-only sub-recipes) are handled correctly: each sample contributes
where its labels are valid and is masked elsewhere.

Limitations / known follow-ups:

- Text loss runs an additional prefix-only forward separate from the
  flow path's prefix forward. The forwards could share their prefix
  computation; for clarity of this first commit they don't.
  Optimization is straightforward when needed.
- Per-sample loss for ``reduction="none"`` is not yet meaningfully
  defined for the dual path — we broadcast the scalar to (B,) for
  caller compatibility (e.g. RA-BC weighting will need follow-up).
- Inference ``select_action`` is unchanged from SmolVLA today —
  it predicts actions only. A separate "generate text"
  ``select_message`` path is the natural next step for runtime
  use of the LM head (memory updates, plan refreshes, VQA answers).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 19:54:57 +02:00
Pepijn 37b1eb218a feat(smolvla2): chat-template processor + label mask + predict_actions
Wires PR 1's recipe stack into the SmolVLA2 pipeline so multi-target
sub-recipes (memory_update, ask_vqa, user_interjection_response,
high_level_subtask) carry meaningful supervision through to the model.

- New ``chat_processor_smolvla2.py`` with
  ``SmolVLA2ChatTokenizerStep``: reads ``messages`` /
  ``message_streams`` / ``target_message_indices`` from the rendered
  sample (PR 1 ``RenderMessagesStep``), calls
  ``apply_chat_template(messages, tools=DEFAULT_TOOLS, ...)`` on the
  SmolVLM tokenizer, and writes:

    OBS_LANGUAGE_TOKENS / _ATTENTION_MASK   ← chat-templated prompt
    text_labels                              ← -100 except target msg tokens
    predict_actions                          ← True iff any low_level target

  Builds the label mask robustly by re-rendering the chat through
  each target's prefix and reading off the prefix length — same
  tokenizer, same tools, so the prefix tokens are guaranteed to be
  a prefix of the full sequence. Image/video content blocks
  (LeRobot ``feature``-keyed) are stripped before tokenizing; the
  actual image tensors flow through SmolVLA's existing
  ``OBS_IMAGES_*`` channels and ``embed_prefix`` puts them before
  the language embeddings, matching the chat-template-stripped
  text order.

- ``processor_smolvla2.py``: when ``config.recipe_path`` is set,
  build a new pipeline with ``RenderMessagesStep`` +
  ``SmolVLA2ChatTokenizerStep`` instead of SmolVLA's plain
  ``TokenizerProcessorStep``. When ``recipe_path`` is ``None``,
  fall back to SmolVLA's pipeline so unannotated datasets still
  work unchanged. Resolves recipe paths relative to
  ``src/lerobot/configs/`` so ``recipes/smolvla2_hirobot.yaml``
  works directly.

The next commit on this branch picks up ``text_labels`` and
``predict_actions`` from the batch and routes them through the
SmolVLM ``lm_head`` for the actual dual-loss training.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 19:21:03 +02:00
Pepijn 52e1fd35cb feat(tools): src/lerobot/tools/ — runnable tool registry + SayTool
Ships the runtime side of the OpenAI-style function-calling stack
introduced in PR 1 (catalog in ``meta/info.json["tools"]``) and PR 2
(annotation pipeline writes the catalog after a run). One file per
tool — heavy deps stay isolated.

Layout:

- ``base.py`` — :class:`Tool` Protocol: ``name``, ``schema``,
  ``call(arguments)``. Runtime-checkable so tests can use
  ``isinstance(...)``.
- ``registry.py`` — :data:`TOOL_REGISTRY` (name → class) plus
  ``get_tools(meta, **kwargs)`` that instantiates every entry whose
  ``function.name`` is registered. Tools whose name is unknown are
  silently skipped — the schema still rides through the chat
  template, the model just can't actually invoke that tool at
  inference.
- ``say.py`` — :class:`SayTool` wrapping Kyutai's pocket-tts
  (CPU-only, ~100M params, ~6× real-time on a MacBook Air M4).
  Lazy model load: pocket-tts is imported and the voice state
  computed on first ``call(...)`` (or eagerly via ``preload()``).
  Returns the PCM tensor; optionally writes a ``.wav`` to
  ``output_dir`` for offline inspection.
- ``__init__.py`` — re-exports the public surface.

Optional install:

    pip install lerobot[tools]

The ``[tools]`` extra in ``pyproject.toml`` pulls in ``pocket-tts`` +
``scipy`` (for the wav writer). Adding more tools later means a new
file + a registry entry — no new extras unless the tool brings new
deps.

To add your own tool, follow the three-step guide in
``docs/source/tools.mdx`` (PR 1):

  1. Drop ``src/lerobot/tools/<my_tool>.py`` with a ``Tool``-conforming
     class.
  2. Register the class in ``TOOL_REGISTRY`` (this file).
  3. Pre-populate ``meta/info.json["tools"]`` with the schema (or let
     ``lerobot-annotate`` add it on the next run).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:58:04 +02:00
Pepijn 7459dfccb6 feat(policies): scaffold smolvla2 (smolvla + lm_head re-enabled)
PR 3 of the steerable-annotation plan retargeted from Pi0.5 to SmolVLA
because the recipe stack (PR 1 + PR 2) outputs HF/TRL-compatible chat
which a chat-pretrained backbone consumes natively. SmolVLA strips the
SmolVLM ``lm_head`` though, so it can only do flow-matching action
prediction. SmolVLA2 keeps the LM head so the same model can train on
the full Hi Robot / MEM / ECoT blend defined in the plan:

  * action-only sub-recipes  (low_level_execution)        flow loss
  * text-only sub-recipes    (memory_update / ask_vqa /   CE loss on
                              user_interjection_response)  lm_head
  * mixed sub-recipes                                      both summed

This first commit lays down the structural scaffold:

- ``src/lerobot/policies/smolvla2/`` — new package with thin subclasses
  of ``SmolVLAConfig`` / ``SmolVLAPolicy`` so we don't fork the 900-line
  modeling code. ``SmolVLA2Config`` adds ``recipe_path``,
  ``apply_chat_template``, ``text_loss_weight``, ``flow_loss_weight``,
  and ``unfreeze_lm_head``. ``SmolVLA2Policy`` unfreezes the SmolVLM
  ``lm_head`` (and the surrounding norm + last text-model layer SmolVLA
  freezes) when ``unfreeze_lm_head=True`` and ``text_loss_weight>0``.
- ``factory.py`` registers ``smolvla2`` in ``get_policy_class``,
  ``make_policy_config``, and the pre/post-processor builder. Important:
  the ``smolvla2`` branch lives BEFORE the ``isinstance(config,
  SmolVLAConfig)`` check because ``SmolVLA2Config`` subclasses
  ``SmolVLAConfig`` — without the ordering, SmolVLA2 would silently
  pick up SmolVLA's processor.
- ``configs/recipes/smolvla2_hirobot.yaml`` — canonical Hi Robot blend
  for SmolVLA2. Same shape as ``pi05_hirobot.yaml`` (PR 1) so the
  recipe stack stays uniform across policy backbones.

Behaviour today is identical to SmolVLA: the modeling forward
delegates to ``SmolVLAPolicy.forward`` and the processor delegates to
``make_smolvla_pre_post_processors``. The next commit on this branch
adds the chat-template processor + ``text_labels`` / ``predict_actions``
batch keys; the commit after that wires the actual text-loss path
through ``vlm.lm_head``.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:55:23 +02:00
Pepijn 73740ecf4b feat(annotate): write tool catalog to meta/info.json after annotation
After every ``lerobot-annotate`` run, the executor ensures
``meta/info.json["tools"]`` contains at minimum the canonical ``say``
schema, while preserving any tools the user pre-declared on the
dataset. Chat-template consumers (PR 3 SmolVLA2 / Pi0.5 / dataset
visualizer) read the catalog through
``LeRobotDatasetMetadata.tools`` and pass it to
``apply_chat_template(messages, tools=meta.tools, ...)``.

- ``executor.py``: new ``_ensure_tools_in_info`` helper called
  after the parquet rewrite. Idempotent and additive — merges by
  ``function.name``, only writes back if the list changed.
- ``writer.py``: drops the duplicated ``SAY_TOOL_SCHEMA`` /
  ``DEFAULT_TOOLS`` constants in favour of importing from
  ``lerobot.datasets.language`` (PR 1's single source of truth).
  Re-exported so existing imports keep working.
- ``annotation_pipeline.mdx``: replace the "code constant only" note
  with a pointer to the new Tools doc and a description of the
  meta/info.json behaviour, including how to pre-declare custom
  tools before annotation runs.

This is the storage half of the tools work; PR 3 ships the runnable
implementations under ``src/lerobot/tools/`` (one file per tool,
first up: ``say.py`` wired to Kyutai's pocket-tts).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:51:38 +02:00
Pepijn 1b81e49214 feat(annotate): task rephrasings + video-derived task fallback
Module 1 now produces ``task_aug`` rows (registered in PR 1) so the
PR-1 ``${task}`` resolver can rotate phrasings deterministically per
``sample_idx``. Plus an opt-in video-derived task that bypasses the
canonical ``meta/tasks.parquet`` task when it's empty, low-quality, or
explicitly disabled — every downstream Module-1 prompt then uses the
derived task as its grounding.

- ``Module1Config``: adds ``n_task_rephrasings`` (default 10) and
  ``derive_task_from_video`` ∈ ``{off, if_short, always}`` (default
  ``if_short``: triggers when canonical is empty, < 3 words, or matches
  a placeholder string like ``debug`` / ``unnamed`` / ``tbd``).
- ``plan_subtasks_memory.py``: ``run_episode`` now resolves an
  ``effective_task`` (canonical OR video-derived) and threads it
  through ``_generate_subtasks`` / ``_generate_plan`` /
  ``_generate_memory`` so subtasks, plans, and memory are all grounded
  in the same task string. Then generates ``n`` rephrasings of the
  effective task and writes them as ``task_aug`` rows at ``t=0`` with
  ``role=user``. The effective task itself is included as the first
  variant so the rotation is guaranteed to cover the source-of-truth
  phrasing.
- New prompts: ``module_1_video_task.txt`` (one-shot video → task),
  ``module_1_task_rephrasings.txt`` (text-only paraphraser, ``n`` per
  call).
- ``meta/tasks.parquet`` is NOT modified — derived tasks live only in
  ``language_persistent``.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:36 +02:00
Pepijn d813c75b76 fix(annotate): align interjections with the actual demo trajectory
qwen36moe-11 surfaced a deeper semantic problem with mid-episode
interjections: they were generated as *counterfactual* user requests
("actually skip the wipe", "use the blue one instead") but teleop data
is frozen — the robot in the video already executed everything,
including the steps the user "asked to skip". The training signal was
therefore self-contradictory: interjection text said one thing, the
robot's subsequent action stream did the opposite.

Flip the framing. Anchor every interjection at a subtask boundary and
write it as a natural user request for the *upcoming* subtask. The
robot's visible next behavior IS the interjection's effect, so:

  interjection text → plan refresh → action stream

are all consistent with the same observed video.

Concretely:

- ``interjections_and_speech.py``: instead of sampling random
  timestamps from ``frame_timestamps``, walk Module 1's subtask spans
  and sample from the (subtask N → subtask N+1) transitions. Pass both
  the just-finished and the upcoming subtask texts into the prompt.

- ``_window_timestamps``: re-center the multi-frame video window on
  the boundary itself (half the frames cover the end of the previous
  subtask, half cover the start of the next one) so the VLM has the
  same visual conditioning the policy will see at training time.

- ``module_2_interjection.txt``: rewritten. The prompt now states
  explicitly that this is offline data, the robot already committed to
  the next subtask, and the interjection must be a natural request
  that aligns with — not contradicts — the next subtask. Removes the
  "negative task / situated correction" Hi Robot framing because those
  scenarios require online execution to be coherent.

Plan-refresh logic from the previous commit (forwarding interjection
text into the refresh prompt) is unchanged and now reinforces the same
direction: the refreshed plan emphasizes the upcoming subtask the
interjection just asked for.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:36 +02:00
Pepijn 3434d2ef22 fix(annotate): ground interjections in video + propagate text to plan refresh
qwen36moe-10 showed three Module-2 / plan-refresh quality issues that
are not architecture problems — they're prompt-grounding bugs:

1. Interjection prompt passed ``current_subtask = record.episode_task``
   (the WHOLE-episode task), not the actual subtask in force at the
   chosen timestamp. The VLM had no signal about what was visible at
   that moment, so its interjections were generic ("actually skip X"
   where X had nothing to do with the visible activity).

2. Interjection prompt only attached a single frame
   (``frames_at(record, [t_snap])``). With one frozen image the VLM
   couldn't read the ongoing motion. Module 1 already gets the whole
   episode video for subtask decomposition, which is why subtasks are
   well-grounded; Module 2 was the outlier.

3. The plan-refresh prompt told the model "a plan refresh after a user
   interjection at t=X.YZs" but never showed it the interjection
   *text*. So the refreshed plan couldn't actually reflect the user's
   correction — at best it recombined the same step list.

Fix:

- ``interjections_and_speech.py``: Module 2 reads Module 1's subtask
  rows from the same staging tree (executor orders module_1 → module_2
  so they're already there) and resolves the actual ``current_subtask``
  at each chosen timestamp. Pulls a small clip
  (``interjection_window_seconds`` × ``interjection_window_frames``,
  defaulting to 4 frames over the leading 2 s) instead of one frame.
  Drops the silently-zeroing ``len(candidate_ts) // 4`` cap on the
  interjection count.

- ``module_2_interjection.txt``: prompt is rewritten to reference the
  multi-frame visual context and require the interjection to mention
  something visible OR named in the current subtask, not invented.

- ``plan_subtasks_memory.py``: ``run_plan_updates`` now accepts and
  threads through interjection texts. ``_generate_plan(refresh_t,
  interjection)`` injects both the current subtask AND the interjection
  text into the prompt so the refreshed plan can drop / reorder /
  constrain steps to match the user's correction. (Plan still refreshes
  ONLY at user interjections — subtask generation runs ~1 Hz at
  inference, plan re-emission is event-driven.)

- ``executor.py``: forwards ``interjection_texts`` alongside
  ``interjection_times`` to ``run_plan_updates``.

- ``Module2Config``: bumps ``max_interjections_per_episode`` default
  from 1 to 3 and exposes ``interjection_window_seconds`` /
  ``interjection_window_frames``.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:36 +02:00
Pepijn b71e10da6b refactor(annotate): drop dataset-level `tools` parquet column
PR 2 used to write a top-level ``tools`` column on every parquet shard
holding the JSON schema for the ``say`` tool, broadcast identically
across every row. That extends PR 1's schema for no real information
gain — the schema is a fixed code constant, parquet's RLE/dict encoding
collapses it on disk anyway, and HF/TRL chat-template consumers can
just import the constant directly.

PR 2 should fill in PR 1's existing schema, not add to it. So:

- ``writer.py``: stop emitting the ``tools`` column. Strip any legacy
  ``tools`` column from older shards on rerun so the schema converges to
  v3.1. ``SAY_TOOL_SCHEMA`` stays as a public constant (now joined by
  ``DEFAULT_TOOLS = [SAY_TOOL_SCHEMA]``); chat-template policies and the
  visualizer import them directly.
- ``test_writer.py``: replace the "tools column present" assertion with
  one that explicitly checks the column is absent, plus a new test
  asserting the constant's shape.
- ``test_pipeline_recipe_render.py``: drop the tools-column read; assert
  it's not present in the rewritten parquet.
- ``annotation_pipeline.mdx``: update the writer description to note the
  parquet stays small and the schema lives as a code constant.

If multi-tool-set support ever becomes real (datasets with different
tool inventories), the right home is ``meta/info.json["tools"]`` —
adding it later is non-breaking; ripping out a parquet column already
shipped is not.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:36 +02:00
Pepijn 0f6e3230df fix(annotate): decode video frames with PyAV directly
``lerobot.datasets.video_utils.decode_video_frames`` routes
``backend="pyav"`` through ``decode_video_frames_torchvision`` →
``torchvision.io.VideoReader``, but ``VideoReader`` was removed in
torchvision >= 0.22 (the vllm/vllm-openai:latest container ships with
torchvision 0.25). That made every Module 3 frame decode raise
``AttributeError: module 'torchvision.io' has no attribute 'VideoReader'``,
which the previous catch-all silently turned into an empty image list,
which then made every Module 3 prompt skip via the
``not _has_image_block(messages)`` branch and produce zero VQA rows.

Bypass ``video_utils`` entirely. The annotation pipeline only needs
a handful of PIL frames per (episode, ts), so a direct PyAV decode is
both simpler and insulated from torchvision API churn. ``av`` is already
in the install set, no new dependency.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:36 +02:00
Pepijn 2f2e42c4aa log(annotate): warn loudly on first video decode failure
VideoFrameProvider._decode used to swallow every exception silently and
return []. That made Module 3 (VQA) produce zero rows whenever local
video decoding broke (codec, backend, missing file, ...) because every
prompt got skipped via the ``not _has_image_block(messages)`` branch in
general_vqa.py — without any signal in the job log.

Log the first failure with full exception info (subsequent failures
stay quiet to avoid log spam) so this fast-path is debuggable.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:36 +02:00
Pepijn 5ee0104739 log(annotate): surface resolved frame-provider cameras at startup
Print the default and full camera list once at the top of every run so a
silent Module-3-no-op (cam_keys=[]) is visible in the job log instead of
only being discoverable by counting parquet rows after upload.

Also warn loudly when Module 3 is enabled but no cameras resolved, with
a hint about the --vlm.camera_key fallback.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:36 +02:00
Pepijn e064cfcb04 fix(annotate): seed Module 3 cameras from camera_keys + camera_key fallback
Module 3 fast-pathed out (50 episodes in 0.6s) when
``frame_provider.camera_keys`` came back empty even though Module 1/2
worked, because they use ``frame_provider.camera_key`` (singular) and
were happy with the explicit ``--vlm.camera_key=...`` override.

Two fixes:

- ``frames.py``: read ``meta.camera_keys`` (covers both video- and
  image-stored cameras) instead of ``meta.video_keys`` (video-only),
  matching :class:`LeRobotDatasetMetadata`'s canonical accessor. If
  metadata still surfaces nothing but the caller explicitly passed
  ``--vlm.camera_key=<key>``, fall back to ``[<key>]`` — the key is by
  definition known to exist on the dataset.
- ``general_vqa.py``: emit a one-time WARNING log when Module 3 sees
  zero cameras so this never silently produces zero VQA again.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:36 +02:00
Pepijn b3d9494831 docs(annotate): add HF Jobs runner example for lerobot-annotate
A ready-to-run example of launching the annotation pipeline on a
Hugging Face job (h200x2) with two vllm replicas serving
Qwen3.6-35B-A3B-FP8. Lives next to other end-to-end recipes under
examples/.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:36 +02:00
Pepijn 1217fdb6f0 feat(annotate): emit VQA per-camera and propagate camera field
Module 3 now produces one (vqa, user) + (vqa, assistant) pair per
emission tick *per camera* rather than only against the dataset's first
camera. Each emitted row carries the `camera` field added in PR 1
(language-columns), so the resolver can disambiguate per-camera VQA via
`emitted_at(t, style=vqa, role=assistant, camera=...)` without ambiguity.

- `frames.py`: `FrameProvider` Protocol gains a `camera_keys` property
  and a `camera_key=` argument on `frames_at` / `video_for_episode`.
  `VideoFrameProvider` exposes every `observation.images.*` key the
  dataset declares (not just the first) and keys its decode cache on
  `(episode, camera, timestamp)` so per-camera reads don't collide.
  Module 1 / 2 keep their old single-camera behaviour by leaving
  `camera_key=None` (falls back to the default camera).
- `modules/general_vqa.py`: `run_episode` iterates `frame_provider
  .camera_keys` for each emission tick, builds one prompt per camera,
  batches all of them through the VLM, and stamps the resulting rows
  with `camera=<that key>`. Empty `camera_keys` (null provider) makes
  the module a no-op rather than silently emitting untagged rows.
- `writer.py`: `_normalize_persistent_row` / `_normalize_event_row`
  carry `camera` through and call `validate_camera_field` so the
  invariant is enforced at the writer boundary. Event sort key now
  includes `camera` for deterministic ordering when several cameras
  share `(timestamp, style, role)`. `speech_atom` sets `camera=None`.
- `validator.py`: `StagingValidator` gains a `dataset_camera_keys`
  field; `_check_camera_field` enforces the invariant and cross-checks
  every view-dependent row's `camera` against the dataset's known video
  keys. New `_check_vqa_uniqueness_per_frame_camera` flags duplicate
  `(vqa, role)` pairs at the same `(t, camera)`.
- `lerobot_annotate.py`: passes the live frame provider's
  `camera_keys` into the validator so the cross-check uses the actual
  dataset camera set.
- Tests: `_StubFrameProvider` exposes `camera_keys` and accepts the new
  `camera_key=` kwarg. `test_module3_vqa_unique_per_frame_and_camera`
  configures two cameras and asserts both are represented, that every
  emitted row has a `camera` tag, and that uniqueness holds per
  `(timestamp, camera, role)`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:36 +02:00
Pepijn d0388e1142 fix(annotate): transcode subclips to H.264 instead of stream-copy
Modern LeRobot datasets store videos in AV1, which vllm's libav build
cannot decode (the video processor returns 0 frames and downstream
chokes with ZeroDivisionError). Re-encode each per-episode subclip
with libx264 (preset ultrafast, crf 23) so the resulting mp4 is
universally decodable. Strip audio with -an for a smaller payload.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:36 +02:00
Pepijn 524aa59faa feat(annotate): pack multiple vllm replicas per GPU via num_gpus
Adds VlmConfig.num_gpus so parallel_servers can exceed the physical
GPU count. Replicas are round-robin-assigned to GPUs (e.g.
parallel_servers=4 + num_gpus=2 → replicas pinned to GPUs 0,1,0,1).
Backward-compatible: num_gpus=0 keeps the existing 1-replica-per-GPU
behavior.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:35 +02:00
Pepijn 27f7829b09 feat(annotate): forward chat_template_kwargs to OpenAI extra_body
Lets callers pass per-request template flags such as
{"enable_thinking": false} for Qwen3.5/Qwen3.6 models, where the
default thinking preamble otherwise consumes the entire max_new_tokens
budget before any JSON is emitted.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:35 +02:00
Pepijn 7f8bf108e8 fix(annotate): include prompt .txt files in wheel
The setuptools package-data declaration only listed envs/*.json, so
pip-installed wheels (including HF Jobs runs) were missing the
module_1_subtasks/plan/memory and module_2/3 prompt templates,
causing FileNotFoundError at runtime.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:35 +02:00
Pepijn 855ff027f8 refactor(annotate): drop HF Inference Providers code path
Default backend is now a local OpenAI-compatible server (vllm /
transformers) which auto_serve spawns. Removes the
use_hf_inference_providers config flag and the router.huggingface.co
routing branch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:35 +02:00
Pepijn 3b797bb118 feat(annotate): --vlm.push_to_hub uploads the annotated dataset
After the pipeline completes, optionally create/locate a dataset repo
and upload the dataset root (excluding .annotate_staging/). Add
push_private and push_commit_message knobs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:35 +02:00
Pepijn aea04721ae feat(annotate): parallelize episodes within each module phase
Saturates parallel_servers + client_concurrency. Previously the
executor processed one episode at a time, so each Module 1 episode's
3-5 dependent VLM calls hit a single server with the others idle. Now
defaults to 16 episodes in flight; configurable via
ExecutorConfig.episode_parallelism.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:35 +02:00
Pepijn ab5479129a fix(annotate): probe /v1/models for spawn-helper readiness
vllm with --uvicorn-log-level warning suppresses the "Uvicorn running"
banner that the readiness watcher waited for, so the spawn helper hung
forever even after the API was live. Add an HTTP probe in parallel with
the log watcher and broaden the log markers to include vllm's own
"Starting vLLM API server" / "Available routes are" lines.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:35 +02:00
Pepijn e6d4ac6f02 fix(annotate): lock-protect per-line writes for parallel server streams
8 server-streaming threads writing chars unsynchronized cause UTF-8
sequences from different servers to interleave mid-byte, garbling the
terminal output. Switch to line-buffered reads with a single shared
print lock — output stays readable, ready-marker detection still works
on the line containing 'Uvicorn running' / 'Application startup
complete'.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:35 +02:00
Pepijn 5722d365c5 feat(annotate): client_concurrency for parallel in-flight requests
Adds vlm.client_concurrency (default 16) which uses a ThreadPoolExecutor
to fan out batched chat.completions calls. vllm batches them internally
on the server side, giving big throughput wins on a single TP=1 server
without needing DP/TP and the NCCL setup it requires.

Module 3 now batches all per-episode VQA calls into a single
generate_json invocation so they fire in parallel.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:35 +02:00
Pepijn 3d7e60cee4 feat(annotate): parallel_servers spawns N independent vllm replicas
Adds --vlm.parallel_servers=N. Spawns N independent vllm processes
(each pinned to GPU i via CUDA_VISIBLE_DEVICES, listening on
serve_port+i) and round-robins requests across them. Sidesteps DP/TP
NCCL setup failures on nodes with restricted P2P/SHM.

Default serve_command for parallel mode: vllm serve <model_id>
--tensor-parallel-size 1 --max-model-len 32768 --uvicorn-log-level
warning. Override via --vlm.serve_command (use {port} placeholder).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:35 +02:00
Pepijn 7b767d4d60 feat(annotate): per-episode progress logs in executor 2026-04-30 18:48:35 +02:00
Pepijn f1e3ab7794 fix(annotate): don't crash pipeline on persistent JSON parse failure
Some prompts/models occasionally return pure prose with no JSON object
even on retry. Returning None (and logging a preview) lets the pipeline
skip that one VLM call cleanly instead of aborting the whole episode.
The modules already check for None / non-dict results and degrade
gracefully (no row emitted from that call).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:35 +02:00
Pepijn 585341ba9f fix(annotate): robust JSON extraction (think tags + first balanced object)
Models often wrap JSON in prose or <think>...</think> blocks. Strip the
think tags first, then try direct json.loads, then fall back to scanning
for the first balanced {...} substring (ignoring braces inside strings).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:35 +02:00
Pepijn 23ff346027 fix(annotate): stream child stdout char-by-char so tqdm \\r progress flushes 2026-04-30 18:48:35 +02:00
Pepijn 3c5cbe7af4 test(annotate): adjust video-block test for fps-based frame sampling 2026-04-30 18:48:35 +02:00
Pepijn f2cbd97635 feat(annotate): Module 1 samples image frames at fps rate
Replace the fixed max_video_frames count with a rate (default 1 fps).
A 30 s episode now sends 30 frames; a 5 s episode sends 5; capped at
max_video_frames (default 128) to avoid blowing up the payload on long
episodes.

Override with --module_1.frames_per_second=2.0 for denser sampling, or
--module_1.frames_per_second=0.5 for sparser.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:35 +02:00
Pepijn c06c8d594a feat(annotate): use cached HF token from huggingface-cli login
Fall back to huggingface_hub.get_token() when HF_TOKEN/HUGGINGFACE_API_KEY
env vars aren't set. That picks up the token cached by
'huggingface-cli login' so users don't need to export it on every shell.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:35 +02:00
Pepijn cd495a3a9d feat(annotate): default to HF Inference Providers, no local GPU needed
Flip the default backend to 'openai' with use_hf_inference_providers=True
and a Qwen3-VL-30B-A3B-Instruct:novita default model_id. The CLI now
runs end-to-end without a local model load — annotations are produced
by sending video_url + prompt to https://router.huggingface.co/v1.

Switch back to local inference with --vlm.backend=vllm or
--vlm.use_hf_inference_providers=false.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn c99ac45cd1 feat(annotate): one-flag HF Inference Providers backend
Setting --vlm.use_hf_inference_providers=true routes requests through
https://router.huggingface.co/v1 using HF_TOKEN as the API key, and
disables auto_serve so no local server is spawned. Combine with a
provider-pinned model id like 'Qwen/Qwen3-VL-30B-A3B-Instruct:novita'
or any plain model id to let HF route.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn 13aaafeae0 fix(annotate): omit mm_processor_kwargs by default; transformers serve rejects it
transformers serve returns HTTP 422 'Unexpected fields' when
mm_processor_kwargs is in extra_body — that field is vllm-specific.
Drop it by default; opt in via LEROBOT_OPENAI_SEND_MM_KWARGS=1 when
talking to vllm serve.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn 2129648bf4 fix(annotate): mm_processor_kwargs in extra_body; inline file URLs as data URLs
Two fixes for video_url with transformers serve:
- fps must be in extra_body.mm_processor_kwargs, not in the content
  block; otherwise the server discards it as unknown kwargs.
- file:// URLs aren't fetched by transformers serve. Read the local mp4
  and inline it as a base64 data:video/mp4 URL so the server sees the
  bytes directly.

Both surface as std::bad_alloc on the server side when wrong, which is
unhelpful but explains what we hit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn f5cd3f6e4e fix(annotate): detect server ready via stdout banner, not /v1/models polls
transformers serve rescans the HF cache on every /v1/models request
which exceeds the 2s urllib timeout, leaving the probe loop spinning
even after Uvicorn is fully up. Watch the streamed server output for
'Uvicorn running' / 'Application startup complete' instead.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn ecf5766301 fix(annotate): visible auto_serve via stdout prints + live server log stream
The previous logger-based output never appeared, leaving users in the
dark when auto_serve silently no-op'd. Switch to print(flush=True) so
the spawn decision is unmistakable, and stream the server's stdout to
the parent terminal in real-time on a background thread so model-load
progress and errors surface immediately.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn 11597d4f71 fix(annotate): auto_serve defaults to True; probe before spawning
Default auto_serve to True so lerobot-annotate can drive the entire
flow with one command. Probe api_base/models first — if a server is
already reachable (user started one manually, or it's a remote
endpoint), skip the spawn.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn 8b9c598cf4 feat(annotate): auto_serve mode spawns and tears down inference server
Setting --vlm.auto_serve=true with --vlm.backend=openai makes the CLI
launch 'transformers serve <model_id> --port <serve_port>
--continuous-batching' as a child process, poll /v1/models until ready
(up to serve_ready_timeout_s), run the pipeline, then SIGINT the
server on process exit.

Override the spawn command with --vlm.serve_command='vllm serve ...'
or any OpenAI-compatible launcher.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn b325475b38 feat(annotate): video_url block for openai backend
Module 1 can now send the episode's actual mp4 file as a video_url
content block instead of pre-decoded frames. The server (transformers
serve / vllm serve / ktransformers serve) handles frame sampling at
the configured fps. Default fps=1 (one frame per second is enough for
subtask-boundary detection on manipulation episodes).

A per-episode subclip is extracted to <root>/.annotate_staging/.video_clips/
via ffmpeg stream-copy (no re-encode) so the model sees only this
episode's frames, not the whole shard.

Enable with --module_1.use_video_url=true (and --vlm.backend=openai).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn ef137ff86a feat(annotate): openai-compatible backend for transformers/ktransformers serve
Adds a third backend that talks to any OpenAI-compatible server. This
unblocks Qwen3.6 (and other models) that work in transformers serve /
ktransformers but not in vllm 0.10.2's fallback path:

- launch the server out-of-process (transformers serve, vllm serve,
  ktransformers serve)
- point lerobot-annotate at it via --vlm.backend=openai
  --vlm.api_base=http://localhost:8000/v1 --vlm.model_id=...

Image and video blocks are converted to OpenAI image_url/video_url
data URLs automatically.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn c5df821a96 fix(annotate): use vllm.chat() API for multimodal prompts
vllm.generate() expects a string/TextPrompt; passing message dicts
fails. vllm.chat() applies the chat template and extracts image/video
blocks automatically, which is what we need for VL models.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn 7ec3d7999c fix(annotate): drop guided_decoding=dict (api differs across vllm)
vllm 0.10.2 expects guided_decoding to be a GuidedDecodingParams object,
not a dict. Different vllm versions differ here. The parser already has
a one-retry JSON-recovery path, so drop guided decoding entirely for
portability.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn 712d63abbd fix(annotate): tolerate decoder returning fewer frames than requested
pyav (and sometimes torchcodec) decode can return fewer frames than
requested timestamps when some timestamps fall outside the video file's
content range. Drop the strict=True on the zip and rely on the
None-filter to discard missing frames.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn 6653999983 fix(annotate): default video decode backend to pyav
torchcodec's __init__ bad-allocs on the cu128/torch-2.8 stack in some
environments (Lustre/conda combos). The annotation pipeline calls
decode_video_frames many times per episode, so this is a hard blocker.
Default to pyav (always available via the av package) and let users
opt back into torchcodec via LEROBOT_VIDEO_BACKEND=torchcodec.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn 4bdbedc9a0 fix(annotate): default trust_remote_code=False for HF loaders
Setting trust_remote_code=True unconditionally pulled custom loader
code that triggers std::bad_alloc post-load on Qwen3-VL — the official
transformers class is sufficient. Flip the default to False; keep the
config field so users can opt in for models that actually need it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn e240305e8e fix(annotate): default transformers backend to manual GPU placement
Loading Qwen3-VL via transformers + accelerate's device_map='auto'
fails with std::bad_alloc on hosts with abundant RAM. The bug is in
accelerate's post-load dispatch path. Bypassing accelerate by loading
to CPU first and then calling .to('cuda') manually avoids that path.

LEROBOT_TRANSFORMERS_DEVICE_MAP=auto switches back to the old behavior
for cases where it works.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn ccd189b264 fix(annotate): LEROBOT_DISABLE_CUDNN escape hatch for conv3d crash
cuDNN 9.x + torch 2.8 has a regression where the conv3d kernel used in
Qwen-VL vision tower patch embedders fails with
CUDNN_STATUS_NOT_INITIALIZED. The crash is independent of model size
and reproduces on both Qwen2.5-VL and Qwen3-VL because both use 3D conv
for video patch embedding.

Setting LEROBOT_DISABLE_CUDNN=1 falls back to native PyTorch conv3d
kernels (slower but functional) so the pipeline can run while the
torch/cuDNN stack is still on the broken combo.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:34 +02:00
Pepijn ef1242bbd4 fix(annotate): expose gpu_memory_utilization and max_model_len for vllm
Large VL models (Qwen3-VL-30B-A3B BF16) take ~58 GB of an 80 GB H100,
leaving only ~22 GB for KV cache + cuDNN workspace. The vision tower's
3D conv then fails with CUDNN_STATUS_NOT_INITIALIZED because cuDNN
can't grab a workspace large enough.

- vlm.gpu_memory_utilization (default 0.9) — drop to 0.7 when the vision
  encoder needs more cuDNN workspace.
- vlm.max_model_len — cap context to free KV cache memory; the 262k
  default for Qwen3 is wildly more than annotation prompts need.
- vlm.trust_remote_code — already plumbed; now also passed to LLM().

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:33 +02:00
Pepijn ebf4a04d41 fix(annotate): pass trust_remote_code=True to HF auto-classes
Required for many newer VL checkpoints (Qwen3.x FP8 in particular) that
ship custom loader code in their repo. Without it, the FP8
weight_scale_inv parameters never bind to FP8Linear modules and the
post-load dispatch path bad-allocs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:33 +02:00
Pepijn 4419b4ef1b fix(annotate): low_cpu_mem_usage=True on transformers load path
The std::bad_alloc we hit on Qwen3-line VL models is not a real OOM —
it triggers in the post-load tensor-placement path even on hosts with
2 TB RAM. low_cpu_mem_usage=True bypasses the offending intermediate
staging buffer and is the standard accelerate workaround.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:33 +02:00
Pepijn ff06ca82d2 fix(annotate): use device_map='auto' for transformers backend
Without device_map, transformers stages the full FP8 checkpoint in CPU
RAM before any GPU placement, OOMing the host on 27B+ models even when
the GPU has enough VRAM. device_map='auto' streams shards directly to
GPU memory.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:33 +02:00
Pepijn fcb01e73eb fix(annotate): try AutoModelForImageTextToText first, fall back to AutoModelForVision2Seq
Newer transformers versions renamed/removed AutoModelForVision2Seq in
favour of AutoModelForImageTextToText for VL models. Try the new name
first and fall back gracefully so the transformers backend works on
both transformers 4.45-4.5x and 5.x.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:33 +02:00
Pepijn 268f8d1f53 fix(annotate): replace Literal types with str for older draccus
Older draccus versions (e.g. 0.10.x bundled in some envs) lack a decoder
for typing.Literal and raise:
  No decoding function for type typing.Literal['vllm', 'transformers', 'stub']

Switching VlmConfig.backend from Literal to str works under every
draccus version. The runtime branch in vlm_client.make_vlm_client
already validates the value and raises ValueError on unknown backends,
so the constraint stays enforced.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:33 +02:00
Pepijn 663fff0ae2 feat(annotate): Module 1 sees the whole episode as one video block
Replaces keyframe sampling with a single Qwen-VL video block covering
the whole demonstration. The model pools temporally itself and chooses
where to cut subtasks — no stride, no count, no keyframe count knob to
tune.

- frames.py: ``FrameProvider`` gains ``video_for_episode(record,
  max_frames)``; ``VideoFrameProvider`` samples up to ``max_frames``
  uniformly across the episode duration; ``_NullProvider`` returns []
  for the no-video fallback. New ``to_video_block`` helper.
- Module 1: drops keyframe sampling. The subtask prompt now goes out as
  ``[{"type":"video", "video":[<frames>]}, {"type":"text", ...}]`` and
  the prompt template asks the model to "watch the whole clip, then
  segment it" with cut points decided from gripper/contact/regrasp
  events the model sees.
- Module1Config: ``keyframes_per_episode`` removed; replaced with
  ``max_video_frames: int = 32`` (model-capacity bound, not annotation
  logic).
- Test: ``test_module1_attaches_video_block_to_subtask_prompt`` locks in
  the single-video-block invariant.
- Stub-VLM markers updated: tests now key on "atomic subtasks" instead
  of the old "Decompose the demonstration" phrase that no longer
  appears in the prompt.
- Docs: updated to describe the whole-episode video-block behavior and
  the no-video fallback.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:33 +02:00
Pepijn 9d6af804bf feat(annotate): attach camera keyframes to module prompts; default to Qwen3.6-27B-FP8
Closes the visual-grounding gap flagged after the initial PR review:
modules now decode actual camera frames at the relevant timestamps and
attach them as `{"type":"image", "image":<PIL>}` content blocks to the
VLM prompts.

- New `frames.py`:
  - `FrameProvider` Protocol; `VideoFrameProvider` decodes from the
    dataset's first `observation.images.*` stream via
    `LeRobotDatasetMetadata.get_video_file_path` and
    `decode_video_frames`, with the same `from_timestamp` shift the main
    dataset uses.
  - Per-process LRU cache so co-timestamped Module 1 plan-update + Module
    2 calls share decode work.
  - `make_frame_provider` falls back to a null provider when the dataset
    has no video tracks → text-only prompts (graceful absence).
- Modules 1/2/3 take an optional `frame_provider` (default null) and
  prepend image blocks before the text block.
  - Module 1 attaches `keyframes_per_episode` keyframes to the subtask
    decomposition prompt.
  - Module 2 attaches the frame at the interjection timestamp.
  - Module 3 attaches the exact emission frame to each VQA pair.
- VlmConfig: backend now defaults to `vllm`; default model is
  `Qwen/Qwen3.6-27B-FP8`. New knobs: `--vlm.tensor_parallel_size`,
  `--vlm.camera_key` (override the keyframe stream).
- `_make_vllm_client` honours `tensor_parallel_size` so 27B-FP8 sharded
  on 2× GPUs works out of the box.
- `test_module3_attaches_frame_image_block_to_prompt` asserts modules
  emit one image block per VQA prompt at the exact emission timestamp.
- Docs: example switched to `imstevenpmwork/super_poulain_draft` +
  Qwen3.6-27B-FP8 + tensor_parallel_size=2; documents the keyframe
  attachment behaviour and the no-video fallback.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:33 +02:00
Pepijn f763f85213 feat: language annotation pipeline (PR 2/3)
Adds the steerable annotation pipeline (`lerobot-annotate`) that populates
the `language_persistent` and `language_events` columns introduced in
PR 1 directly into `data/chunk-*/file-*.parquet`. No flavor namespace,
no sidecar tree.

Modules produced:
- Module 1 (plan_subtasks_memory): Pi0.7-style subtasks, plan (init +
  refresh on interjection), MEM-style memory at subtask boundaries.
- Module 2 (interjections_and_speech): t=0 speech-only acknowledgement,
  mid-episode paired interjection + speech tool-call atom.
- Module 3 (general_vqa): bbox/keypoint/count/attribute/spatial pairs at
  configurable cadence with one-retry JSON validation.

Writer enforces: per-episode persistent identity, exact-frame event
timestamps, column routing per `column_for_style`, dataset-level `tools`
column with the `say` schema, drops legacy `subtask_index`. Validator
runs against staged JSONL artifacts before the writer rewrites parquet.

Adds `lerobot-annotate` console script, `annotations` extra (datatrove +
optional vllm), `make annotation-e2e` opt-in smoke target, and
`docs/source/annotation_pipeline.mdx`.

Branched from PR 1 (`feat/language-columns`).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:48:33 +02:00
Pepijn e3e9374e2c feat(language): tool catalog in meta/info.json + LeRobotDatasetMetadata.tools
Stores OpenAI-style function schemas at ``meta/info.json["tools"]`` so
datasets can declare which tools are available (today: just ``say``;
tomorrow: per-dataset extensions). The ``DEFAULT_TOOLS`` constant
fills in for unannotated datasets so chat-template consumers don't
have to special-case anything.

Three pieces:

- ``language.py``: ``SAY_TOOL_SCHEMA`` and ``DEFAULT_TOOLS``
  constants. Single source of truth — PR 2's writer and PR 3's
  runtime tool registry will both import from here instead of
  duplicating the dict.
- ``dataset_metadata.py``: ``LeRobotDatasetMetadata.tools`` property
  reads ``info.json["tools"]`` and falls back to ``DEFAULT_TOOLS``.
  Returns deep-copied dicts so callers can mutate the result safely.
- ``docs/source/tools.mdx``: spec page covering the catalog, per-row
  invocations, and the three-step "how to add a new tool" workflow
  (declare schema, implement, register). Linked from the docs
  toctree under the Datasets section.

This lays the groundwork for PR 2's pipeline writing the catalog out
during annotation, and PR 3's ``src/lerobot/tools/`` package shipping
runnable implementations (one file per tool — first up:
``say.py`` wrapping Kyutai's pocket-tts).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:44:58 +02:00
Pepijn c1a0c601e2 feat(language): task_aug style + automatic ${task} rephrasing rotation
Adds task-prompt diversity (Xiao 2022 / CAST) without touching
``meta/tasks.parquet`` or forcing recipes to opt in. The plan reserved
``task_aug`` as a future style; this lands it now.

- ``language.py``: add ``task_aug`` to ``CORE_STYLES`` and
  ``PERSISTENT_STYLES``. ``column_for_style("task_aug")`` returns
  ``language_persistent`` so PR 2 writers route it correctly.

- ``language_render.py``: ``_resolve_task`` now consults the persistent
  slice for rows of ``style="task_aug", role="user"``. When any exist
  it picks one deterministically by ``sample_idx`` (blake2b-keyed, not
  Python's randomized hash) so an epoch sees every rephrasing of every
  episode while the same sample still resolves identically across
  reruns. Falls back to the canonical ``meta/tasks.parquet`` task when
  no rephrasings are present, so existing datasets and unannotated runs
  keep their behaviour. Explicit ``task=`` overrides still win.

- Tests: rephrasing coverage across samples, determinism on repeat
  ``sample_idx``, fallback when persistent has no ``task_aug`` rows,
  and explicit override priority.

Recipes get this for free: any ``${task}`` placeholder rotates through
the available rephrasings. Recipes that want the literal canonical task
can override the binding.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 16:45:39 +02:00
Pepijn 1ca38d9748 fix(language): drop motion from VIEW_DEPENDENT_STYLES
Motion primitives are described in robot-frame (joint / Cartesian) terms,
not pixel space, so they are camera-agnostic. Only `vqa` (event) and
`trace` (event, pixel-trajectory) are view-dependent.

The `camera` field stays on PERSISTENT_ROW_FIELDS for schema symmetry —
the validator, resolver, and HF feature mapping behave identically across
the two columns regardless of which styles populate `camera` today —
but persistent rows now always have `camera=None` in practice.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 10:54:12 +02:00
Pepijn 5a6aa64570 feat(language): per-camera tagging on view-dependent styles
Adds a nullable `camera` field to the language row struct (both persistent
and event variants) so view-dependent styles like `vqa` can carry which
`observation.images.*` view they were grounded against. Without this,
multi-camera datasets ended up with multiple `(vqa, role)` rows at the
same timestamp that the resolver could not disambiguate.

- `language.py`: add `camera` to PERSISTENT_ROW_FIELDS / EVENT_ROW_FIELDS,
  to both Arrow struct types and the HF datasets feature mappings;
  introduce VIEW_DEPENDENT_STYLES = {vqa, motion, trace} plus
  `is_view_dependent_style` and `validate_camera_field` helpers (camera
  required iff style is view-dependent).
- `language_render.py`: thread an optional `camera=` kwarg through every
  resolver (`active_at`, `emitted_at`, `nth_prev`, `nth_next`) and through
  `_matching_rows` / `_select_*`, so recipes can disambiguate per-camera
  VQA with `emitted_at(t, style=vqa, role=assistant, camera=...)`.
  Without a `camera` filter, multi-row matches keep raising the existing
  ambiguity error — which is the desired behaviour on multi-camera data.
- `recipes/pi05_hirobot.yaml`: replace the single `ask_vqa` branch with
  `ask_vqa_top` and `ask_vqa_wrist` per-camera sub-recipes (each carrying
  the matching image block), keeping the original 0.20 budget and
  documenting the customization point for datasets with different cameras.
- Tests: schema test asserts the new field order; new tests cover
  `is_view_dependent_style`, `validate_camera_field` (both required and
  forbidden directions), per-camera `emitted_at` filtering, and the
  ambiguity error when two cameras emit `(vqa, assistant)` at the same
  timestamp without a `camera=` filter. RenderMessagesStep + dataset
  passthrough fixtures updated to include the new field.
- `docs/source/language_and_recipes.mdx`: document the `camera` field,
  the per-camera resolver pattern, and the canonical recipe convention.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 10:48:17 +02:00
Pepijn 0b06790da0 feat(language): add motion (persistent) and trace (event-only) styles
Promote the previously-reserved motion/trace styles to first-class core
styles. motion routes to language_persistent (it tracks robot state over
time); trace routes to language_events (single-moment annotations).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 14:21:49 +02:00
Pepijn b43dc39ba4 Add docstrings to all new helpers; revert uv.lock
Covers private helpers in recipe.py, language.py, language_render.py,
and render_messages_processor.py. Also reverts uv.lock to main (it was
re-generated by `uv run` during local checks).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 14:15:03 +02:00
Pepijn 2b71221194 Address review: split persistent/event schemas, drop event timestamps
- recipe.py: derive _VALID_ROLES/_VALID_STREAMS from MessageRole/MessageStream Literals
- dataset_metadata.py: keep CODEBASE_VERSION at v3.0
- language.py: remove RESERVED_STYLES; split arrow/feature schemas into
  persistent (with timestamp) and event (without timestamp); add docstrings
- language_render.py: events use frame-row timestamp implicitly; no
  per-event timestamp filtering or sorting
- converters.py: drop unused subtask_key passthrough
- add docstrings to new public APIs (recipe, render_messages_processor, collate)
- update tests for split schemas; revert uv.lock

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 13:38:23 +02:00
Pepijn 8833d735a1 Add extensive language support 2026-04-27 10:56:32 +02:00
101 changed files with 11671 additions and 2280 deletions
+3 -3
View File
@@ -167,9 +167,9 @@ jobs:
# ── LIBERO TRAIN+EVAL SMOKE ──────────────────────────────────────────────
# Train SmolVLA for 1 step (batch_size=1, dataset episode 0 only) then
# immediately runs eval inside the training loop (env_eval_freq=1, 1 episode).
# immediately runs eval inside the training loop (eval_freq=1, 1 episode).
# Tests the full train→eval-within-training pipeline end-to-end.
- name: Run Libero train+eval smoke (1 step, env_eval_freq=1)
- name: Run Libero train+eval smoke (1 step, eval_freq=1)
if: env.HF_USER_TOKEN != ''
run: |
docker run --name libero-train-smoke --gpus all \
@@ -196,7 +196,7 @@ jobs:
--output_dir=/tmp/train-smoke \
--steps=1 \
--batch_size=1 \
--env_eval_freq=1 \
--eval_freq=1 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--eval.use_async_envs=false \
-3
View File
@@ -65,9 +65,6 @@ repos:
name: Format Markdown with Prettier
types_or: [markdown, mdx]
args: [--prose-wrap=preserve]
# Jinja2 model-card templates use a .md extension but contain {% ... %} /
# {{ ... }} tags that prettier's Markdown formatter mangles (e.g. table loops).
exclude: ^src/lerobot/templates/.*\.md$
##### Security #####
- repo: https://github.com/gitleaks/gitleaks
+4 -4
View File
@@ -58,7 +58,7 @@ test-act-ete-train:
--dataset.episodes="[0]" \
--batch_size=2 \
--steps=4 \
--env_eval_freq=2 \
--eval_freq=2 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--save_freq=2 \
@@ -96,7 +96,7 @@ test-diffusion-ete-train:
--dataset.episodes="[0]" \
--batch_size=2 \
--steps=2 \
--env_eval_freq=2 \
--eval_freq=2 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--save_checkpoint=true \
@@ -126,7 +126,7 @@ test-tdmpc-ete-train:
--dataset.episodes="[0]" \
--batch_size=2 \
--steps=2 \
--env_eval_freq=2 \
--eval_freq=2 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--save_checkpoint=true \
@@ -161,7 +161,7 @@ test-smolvla-ete-train:
--dataset.episodes="[0]" \
--batch_size=2 \
--steps=4 \
--env_eval_freq=2 \
--eval_freq=2 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--save_freq=2 \
+7 -10
View File
@@ -58,7 +58,7 @@ action = model.select_action(obs)
robot.send_action(action)
```
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1, reBot B601.
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1.
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
@@ -101,13 +101,11 @@ lerobot-train \
--dataset.repo_id=lerobot/aloha_mobile_cabinet
```
| Category | Models |
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0](./docs/source/pi0.mdx), [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx), [EO-1](./docs/source/eo1.mdx), [MolmoAct2](./docs/source/molmoact2.mdx), [WALL-OSS](./docs/source/walloss.mdx) |
| **World Models** | [VLA-JEPA](./docs/source/vla_jepa.mdx) (more coming soon) |
| **Reward Models** | [SARM](./docs/source/sarm.mdx), [TOPReward](./docs/source/topreward.mdx), [Robometer](./docs/source/robometer.mdx) |
| Category | Models |
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
@@ -135,7 +133,6 @@ Learn how to implement your own simulation environment or benchmark and distribu
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
## Citation
@@ -143,7 +140,7 @@ If you use LeRobot in your project, please cite the GitHub repository to acknowl
```bibtex
@misc{cadene2024lerobot,
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Meftah, Khalil and Ellerbach, Maxime and Moss, Jess and Wolf, Thomas},
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
howpublished = "\url{https://github.com/huggingface/lerobot}",
year = {2024}
+6 -16
View File
@@ -71,21 +71,11 @@ it uses a two-step **describe → segment** flow:
2. **Segment** — that description is fed back in, and the VLM splits the
episode into consecutive atomic subtasks.
Both passes see the episode as **timestamped contact sheets** — frames
sampled at `frames_per_second` (0.5s by default) and packed into JPEG
grids with each frame's time burned into its corner, so the VLM cites
exact boundary times directly. This is far cheaper in vision tokens than
one image per frame, so the sampling can stay dense; episodes longer than
`max_frames_per_prompt` are split into windows at the same density and
merged. Both prompts also carry a causal **event-boundary** definition (a
new event starts when an object becomes held / is released / reaches a new
location / a lid changes state / contents move) to sharpen where cuts land.
The resulting spans are then stitched into a gap-free, full-episode
cover, so **every frame has exactly one active subtask**. See
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
for the production settings (single camera, timestamped contact sheets,
auto-windowed subtask generation).
for the production settings (single camera, embedded frames, windowed
subtask generation).
### Tools
@@ -172,15 +162,15 @@ Every module is on by default and can be toggled independently (set to
| Flag | Default | What it does |
| ------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------- |
| `--plan.frames_per_second` | `2.0` | Frame sampling rate for the contact sheets (`2.0` = one frame every 0.5s). |
| `--plan.max_frames_per_prompt` | `60` | Frame budget per VLM call. Episodes whose sampling exceeds this are auto-windowed at the same density, then stitched. |
| `--plan.contact_sheet_columns` | `5` | Columns per contact-sheet grid (`contact_sheet_frames_per_sheet` tiles, time row-major). |
| `--plan.frames_per_second` | `1.0` | How densely the episode video is sampled. |
| `--plan.max_video_frames` | `32` | Hard cap on frames per call (context-budget guard — don't exceed ~32 for a 32k context). |
| `--plan.subtask_window_seconds` | `0` | Split long episodes into fixed windows for constant frame density (`0` = whole episode). |
| `--plan.plan_max_steps` | `8` | Upper bound on subtasks per episode. |
| `--plan.subtask_describe_first` | `true` | Run the describe→segment grounding pass (best subtask quality; +1 call/episode). |
| `--plan.emit_plan` | `true` | Emit the numbered `plan` rows (`false` = subtasks + memory only). |
| `--plan.emit_memory` | `true` | Emit the `memory` rows (`false` = subtasks + plan only); symmetric to `emit_plan`. |
| `--plan.n_task_rephrasings` | `10` | How many `task_aug` rephrasings to emit (`0` disables). |
| `--plan.derive_task_from_video` | `if_short` | Use the dataset task as-is (`off`), only when it's missing/short (`if_short`), or always re-derive from video (`always`). |
| `--plan.use_video_url` | `false` | Send a server-side video clip instead of embedded frames. |
### Interjections + VQA
+1 -1
View File
@@ -719,7 +719,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
"num_workers": 4,
"steps": 5000,
"log_freq": 10,
"env_eval_freq": 1000,
"eval_freq": 1000,
"save_freq": 1000,
"save_checkpoint": true,
"seed": 2,
+5
View File
@@ -141,6 +141,11 @@ sample["target_message_indices"]
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
## Blends
Blend recipes select one weighted sub-recipe deterministically from the sample index.
`recipes/subtasks_vqa.yaml` trains the core blend — high-level subtask prediction, low-level execution, and VQA. `recipes/subtask_mem_vqa_speech.yaml` is the fuller variant that also adds memory updates and spoken interjection responses.
## Graceful absence
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
+1 -1
View File
@@ -143,7 +143,7 @@ lerobot-train \
--batch_size=4 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--env_eval_freq=1000
--eval_freq=1000
```
## Reproducing published results
+1 -1
View File
@@ -173,7 +173,7 @@ lerobot-train \
--batch_size=4 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--env_eval_freq=1000
--eval_freq=1000
```
## Relationship to LIBERO
+2 -2
View File
@@ -120,11 +120,11 @@ lerobot-train \
--batch_size=4 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--env_eval_freq=1000
--eval_freq=1000
```
## Practical tips
- Use the one-hot task conditioning for multi-task training (MT10/MT50 conventions) so policies have explicit task context.
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
- Adjust `batch_size`, `steps`, and `env_eval_freq` to match your compute budget.
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.
+2 -2
View File
@@ -103,7 +103,7 @@ accelerate launch \
--batch_size=32 \
--num_workers=4 \
--log_freq=20 \
--env_eval_freq=-1 \
--eval_freq=-1 \
--save_checkpoint=true \
--save_freq=2000
```
@@ -142,7 +142,7 @@ accelerate launch \
--batch_size=32 \
--num_workers=4 \
--log_freq=20 \
--env_eval_freq=-1 \
--eval_freq=-1 \
--save_checkpoint=true \
--save_freq=2000
```
+1 -1
View File
@@ -314,7 +314,7 @@ lerobot-train \
--steps=30000 \
--save_freq=1000 \
--log_freq=100 \
--env_eval_freq=1000 \
--eval_freq=1000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
+1 -1
View File
@@ -166,7 +166,7 @@ lerobot-train \
--output_dir=./outputs/smolvla_robocasa_CloseFridge \
--steps=100000 \
--batch_size=4 \
--env_eval_freq=5000 \
--eval_freq=5000 \
--eval.batch_size=1 \
--eval.n_episodes=5 \
--save_freq=10000
+1 -1
View File
@@ -165,7 +165,7 @@ lerobot-train \
--output_dir=./outputs/smolvla_vlabench_primitive \
--steps=100000 \
--batch_size=4 \
--env_eval_freq=5000 \
--eval_freq=5000 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--save_freq=10000
+35 -3
View File
@@ -53,17 +53,49 @@ CMD = (
"export VLLM_VIDEO_BACKEND=pyav && "
"lerobot-annotate "
"--repo_id=pepijn223/robocasa_pretrain_human300_v4 "
"--new_repo_id=pepijn223/robocasa_pretrain_human300_v4_annotated "
"--new_repo_id=pepijn223/robocasa_pretrain_human300_v4_annotated5 "
"--push_to_hub=true "
"--vlm.backend=openai "
"--vlm.model_id=Qwen/Qwen3.6-27B "
"--vlm.parallel_servers=1 "
"--vlm.num_gpus=1 "
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-27B '
"--tensor-parallel-size 1 --max-model-len 32768 "
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
"--vlm.serve_ready_timeout_s=1800 "
# Qwen3.6 ships with thinking on; annotation wants plain JSON answers.
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}'"
"--vlm.client_concurrency=128 "
"--vlm.max_new_tokens=512 "
"--vlm.temperature=0.7 "
"--executor.episode_parallelism=16 "
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}' "
"--vlm.camera_key=observation.images.robot0_agentview_right "
# Phase 1 — plan module (subtasks + memory).
# Embed decoded frames (not a file:// clip): if clip extraction fails,
# the video_url path silently sends no video and the VLM hallucinates.
"--plan.use_video_url=false "
"--plan.frames_per_second=1.0 "
# 32 frames ≈ 8-10k vision tokens, fits the 32768 context. Don't push
# toward 128 — that overflows the context (BadRequestError 400).
"--plan.max_video_frames=32 "
# Window long episodes into 32s chunks (constant 1 fps density) so they
# get more subtasks; per-window spans are merged + stitched. 0 disables.
"--plan.subtask_window_seconds=32 "
# RoboCasa: the dataset task string is authoritative (eval uses it), so
# keep it driving subtasks. ``always`` would throw it away and hallucinate.
"--plan.derive_task_from_video=off "
# No task augmentation: eval conditions on the exact task strings, so
# rephrasings are unused at best and harmful when they drift.
"--plan.n_task_rephrasings=0 "
# Keep subtask decomposition tight for atomic tasks.
"--plan.plan_max_steps=10 "
# Only subtasks + memory — skip the numbered "plan" rows. true re-enables.
"--plan.emit_plan=false "
# The describe->segment grounding pass (+1 VLM call/episode) is ON by
# default; pass --plan.subtask_describe_first=false to skip it.
# Phase 2 — interjections + speech.
"--interjections.max_interjections_per_episode=6 "
# Phase 4 — general VQA: disabled for this run.
"--vqa.enabled=false"
)
job = run_job(
+26 -16
View File
@@ -85,6 +85,11 @@ dependencies = [
"termcolor>=2.4.0,<4.0.0",
"tqdm>=4.66.0,<5.0.0",
# Training utilities
# EMA of policy parameters (Diffusion Policy / pi05 style). Tiny
# pure-python dependency — preferred over a hand-rolled implementation.
"ema-pytorch>=0.7.7,<1.0.0",
# Build tools (required by opencv-python-headless on some platforms)
"cmake>=3.29.0.1,<4.2.0",
"setuptools>=71.0.0,<81.0.0",
@@ -115,8 +120,8 @@ dataset = [
]
training = [
"lerobot[dataset]",
"wandb>=0.24.0,<0.28.0",
"lerobot[accelerate-dep]",
"accelerate>=1.10.0,<2.0.0",
"wandb>=0.24.0,<0.25.0",
]
hardware = [
"lerobot[pynput-dep]",
@@ -142,8 +147,8 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
placo-dep = ["placo>=0.9.6,<0.9.16"]
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
grpcio-dep = ["grpcio>=1.73.1,<2.0.0", "protobuf>=6.31.1,<8.0.0"]
accelerate-dep = ["accelerate>=1.14.0,<2.0.0"]
sentencepiece-dep = ["sentencepiece>=0.2.0,<0.3.0"] # FAST action tokenizer backend (pi052, pi0_fast)
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
@@ -178,12 +183,7 @@ unitree_g1 = [
"lerobot[matplotlib-dep]",
"lerobot[pygame-dep]",
]
# reachy2-sdk caps grpcio<=1.73.1 and protobuf<=6.32.0; quarantined here so downstream users aren't held back. reachy2-sdk is unlikely to release new versions.
reachy2 = [
"reachy2_sdk>=1.0.15,<1.1.0",
"grpcio<=1.73.1",
"protobuf<=6.32.0",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
# leader (motorbridge-smart-servo / FashionStar UART servos).
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
@@ -203,9 +203,9 @@ wallx = [
"torchdiffeq>=0.2.4,<0.3.0",
"lerobot[qwen-vl-utils-dep]",
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]", "lerobot[sentencepiece-dep]"]
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "lerobot[accelerate-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
groot = [
"lerobot[transformers-dep]",
@@ -222,7 +222,7 @@ robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot
topreward = ["lerobot[transformers-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
# Features
@@ -244,17 +244,25 @@ annotations = [
# install it locally only if you run your own ``vllm serve``.
]
# Tool implementations under src/lerobot/tools/. Each tool's dependencies
# are isolated so adding a new tool doesn't bloat the base install.
# Currently only `say` (Kyutai pocket-tts; CPU-only, ~100M params).
tools = [
"pocket-tts>=1.0.0,<3.0.0",
"scipy>=1.11.0,<2.0.0", # SayTool.output_dir uses scipy.io.wavfile
]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools>=1.73.1,<2.0.0", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
@@ -340,6 +348,8 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
# Interactive hierarchical-VLA runtime for PI052 (PaliGemma backbone).
lerobot-pi052-runtime="lerobot.scripts.lerobot_pi052_runtime:main"
# ---------------- Tool Configurations ----------------
@@ -35,28 +35,14 @@ class PlanConfig:
derive_task_from_video: str = "if_short"
derive_task_min_words: int = 3
# --- Frame input: timestamped contact sheets (always on) ---------------
# The subtask describe/segment passes ALWAYS render the episode as
# macrodata/refiner-style contact sheets: sampled frames packed into JPEG
# grids with each frame's timestamp burned into its corner, so the VLM
# cites the exact source time of a boundary directly. This is far cheaper
# in vision tokens than one image per frame (≈2× faster subtask generation
# in practice), which is why the sampling is dense by default.
#
# ``frames_per_second`` is the sampling rate: 2.0 = one frame every 0.5s.
frames_per_second: float = 2.0
# Frame budget per VLM call (= columns × rows × sheets). When a whole
# episode sampled at ``frames_per_second`` exceeds this, the episode is
# AUTOMATICALLY split into consecutive windows of
# ``max_frames_per_prompt`` frames each (one describe→segment call per
# window, still at the full ``frames_per_second`` density), and the
# per-window spans are merged + stitched into one contiguous cover. So an
# episode of any length is always covered at the full sampling density.
max_frames_per_prompt: int = 60
contact_sheet_columns: int = 5
contact_sheet_frames_per_sheet: int = 20
contact_sheet_frame_width: int = 224
contact_sheet_quality: int = 84
# Frames sampled uniformly, capped at max_video_frames — a hard context cap
# (~300 tokens/frame, so 32 fit a 32k VLM; 128 overflow).
frames_per_second: float = 1.0
max_video_frames: int = 32
# >0: split long episodes into windows of this length (constant fps density)
# instead of subsampling the whole episode; spans merged + stitched. 0 disables.
subtask_window_seconds: float = 0.0
min_subtask_seconds: float = 1.5
plan_max_steps: int = 8
@@ -68,12 +54,12 @@ class PlanConfig:
# Emit ``style="plan"`` rows at each boundary; False = subtasks + memory only.
emit_plan: bool = True
# Emit ``style="memory"`` rows at each boundary; False = subtasks (+ plan) only.
# Symmetric counterpart of ``emit_plan``.
emit_memory: bool = True
# (subtask spans are always stitched to a contiguous full-episode cover; not configurable.)
# Send a server-side ``video_url`` clip (at use_video_url_fps) instead of embedded frames.
use_video_url: bool = False
use_video_url_fps: float = 1.0
# Optional EgoMimic-style 5-axis task augmentation; replaces n_task_rephrasings.
task_aug_axes: TaskAugAxesConfig = field(default_factory=lambda: TaskAugAxesConfig())
@@ -197,9 +183,8 @@ class AnnotationPipelineConfig:
skip_validation: bool = False
only_episodes: tuple[int, ...] | None = None
# Keyframe decode backend forwarded to ``decode_video_frames``. None →
# library default (torchcodec when available, else PyAV). Or pin
# ``"torchcodec"`` / ``"pyav"`` explicitly.
# Keyframe decode backend. None → ffmpeg CLI (crash-/thread-safe; torchcodec
# SIGSEGVs under concurrent decode). Or ``"torchcodec"`` / ``"pyav"``.
video_backend: str | None = None
# Upload to the Hub (new_repo_id if set, else repo_id; one must be set).
@@ -24,11 +24,8 @@ querying the same timestamp pay decode cost once.
from __future__ import annotations
import io
import logging
import math
import threading
from collections.abc import Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Protocol
@@ -36,10 +33,9 @@ from typing import Any, Protocol
import PIL.Image
import torch
from lerobot.configs.video import VideoEncoderConfig
from lerobot.datasets.video_utils import decode_video_frames, reencode_video
from lerobot.datasets.video_utils import decode_video_frames
from .reader import EpisodeRecord, snap_to_frame
from .reader import EpisodeRecord
logger = logging.getLogger(__name__)
@@ -138,9 +134,10 @@ class VideoFrameProvider:
camera_key: str | None = None
tolerance_s: float = 1e-2
cache_size: int = 256
# Keyframe decode backend forwarded to
# :func:`lerobot.datasets.video_utils.decode_video_frames`. ``None``
# uses the library default (torchcodec when available, else PyAV).
# Keyframe decode backend. ``None`` uses the ffmpeg CLI — the
# concurrency- and crash-safe default for the pipeline's threaded
# decode. Set to ``"torchcodec"`` or ``"pyav"`` to pin an in-process
# decoder when the build is known thread-safe.
video_backend: str | None = None
_meta: Any = field(default=None, init=False, repr=False)
_cache: dict = field(default_factory=dict, init=False, repr=False)
@@ -149,10 +146,6 @@ class VideoFrameProvider:
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
# one-shot warn flag against concurrent updates from worker threads.
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
# Serializes decode_video_frames calls: torchcodec hands out one
# ``VideoDecoder`` per file from a process-wide cache, and the decoder
# is not safe to drive from multiple threads at once.
_decode_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
_warned_decode_fail: bool = field(default=False, init=False, repr=False)
def __post_init__(self) -> None:
@@ -188,13 +181,6 @@ class VideoFrameProvider:
target = camera_key if camera_key is not None else self.camera_key
if not timestamps or target is None:
return []
# Snap each request to the nearest real frame timestamp: callers
# sample uniform grids whose points land mid-frame, and
# ``decode_video_frames`` rejects queries farther than
# ``tolerance_s`` from a decodable frame. Snapping also dedupes
# repeat queries through the cache.
if record.frame_timestamps:
timestamps = [snap_to_frame(float(ts), record.frame_timestamps) for ts in timestamps]
out: list[Any] = []
misses: list[float] = []
@@ -258,14 +244,15 @@ class VideoFrameProvider:
def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None:
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
Returns ``None`` if the dataset has no video tracks or extraction
failed. Skips re-extract when the cached clip already exists.
Re-encodes to H.264 via
:func:`lerobot.datasets.video_utils.reencode_video` so the resulting
mp4 is decodable by every downstream video processor — stream-copy
would inherit the source codec (often AV1 in modern LeRobot
datasets), which vllm's libav build cannot decode.
Returns ``None`` if the dataset has no video tracks. Skips
re-extract when the cached clip already exists. Re-encodes to
H.264 (libx264) so the resulting mp4 is decodable by every
downstream video processor — stream-copy would inherit the
source codec (often AV1 in modern LeRobot datasets), which
vllm's libav build cannot decode.
"""
import subprocess # noqa: PLC0415
if self.camera_key is None:
return None
cache_dir.mkdir(parents=True, exist_ok=True)
@@ -276,20 +263,33 @@ class VideoFrameProvider:
from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"])
to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"])
src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key)
encoder = VideoEncoderConfig(vcodec="h264", pix_fmt="yuv420p", g=None, crf=23, preset="ultrafast")
cmd = [
"ffmpeg",
"-y",
"-loglevel",
"error",
"-ss",
f"{from_timestamp:.3f}",
"-to",
f"{to_timestamp:.3f}",
"-i",
str(src),
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-crf",
"23",
"-pix_fmt",
"yuv420p",
"-an",
str(out_path),
]
try:
reencode_video(
src,
out_path,
camera_encoder=encoder,
overwrite=True,
start_time_s=from_timestamp,
end_time_s=to_timestamp,
)
except Exception:
logger.warning(
"clip extraction failed for episode %s (%s)", record.episode_index, src, exc_info=True
)
# ffmpeg is invoked by name via PATH lookup (the standard way to
# call the CLI); the arg list is fully controlled here, not shell.
subprocess.run(cmd, check=True, timeout=300) # nosec B607
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
return None
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
@@ -297,47 +297,61 @@ class VideoFrameProvider:
"""Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors.
Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames`
(torchcodec when available, PyAV otherwise; ``video_backend`` pins
one explicitly). Returns one frame per requested timestamp, or ``[]``
if decoding failed — callers treat ``[]`` as "no frames available".
(torchcodec by default, PyAV fallback) rather than a bespoke decoder.
Returns one frame per requested timestamp, or ``[]`` if decoding
failed wholesale — callers treat ``[]`` as "no frames available".
"""
ep = self._meta.episodes[episode_index]
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
shifted = [from_timestamp + ts for ts in timestamps]
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
try:
# The module phases decode under a ThreadPoolExecutor (see
# ``ExecutorConfig.episode_parallelism``) but torchcodec's cached
# per-file decoder is single-threaded, so serialize decodes on a
# dedicated lock. Frame extraction is a small fraction of episode
# wall time (VLM calls dominate), so the contention is cheap.
with self._decode_lock:
# Default to the ffmpeg CLI. The pipeline decodes under a 16-wide
# ThreadPoolExecutor and the in-process decoders are unsafe there:
# torchcodec is not thread-safe and SIGSEGVs under concurrent decode
# (a crash no try/except can catch), PyAV can likewise segfault on
# AV1, and lerobot's ``pyav`` backend routes through the removed
# ``torchvision.io.VideoReader``. ``_decode_frames_ffmpeg`` shells
# out per frame: each decode is an isolated child process, so it is
# both crash-safe and concurrency-safe. ``video_backend`` can pin
# ``torchcodec`` / ``pyav`` explicitly for callers that know their
# build is safe.
chain = [self.video_backend] if self.video_backend else ["ffmpeg"]
exc: Exception | None = None
for backend in chain:
try:
if backend == "ffmpeg":
return _decode_frames_ffmpeg(video_path, shifted)
if backend in ("pyav", "av"):
return _decode_frames_av(video_path, shifted)
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
decoded = decode_video_frames(
video_path, shifted, self.tolerance_s, backend=self.video_backend, return_uint8=True
video_path, shifted, self.tolerance_s, backend=backend, return_uint8=True
)
return list(decoded)
except Exception as exc:
# Log loudly the first time so a silent vqa-module no-op (every
# prompt skipped because frames_at returned []) is debuggable from
# the job log instead of post-hoc parquet inspection. Subsequent
# failures stay quiet.
with self._lock:
already_warned = self._warned_decode_fail
if not already_warned:
self._warned_decode_fail = True
return list(decoded)
except Exception as e: # noqa: PERF203
exc = e
# Every backend raised. Log loudly the first time so a silent
# vqa-module no-op (every prompt skipped because frames_at returned
# []) is debuggable from the job log instead of post-hoc parquet
# inspection. Subsequent failures stay quiet.
with self._lock:
already_warned = self._warned_decode_fail
if not already_warned:
logger.warning(
"VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s backend=%s: %s",
episode_index,
camera_key,
video_path,
self.video_backend,
exc,
exc_info=exc,
)
return []
self._warned_decode_fail = True
if not already_warned:
logger.warning(
"VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s backends=%s: %s",
episode_index,
camera_key,
video_path,
chain,
exc,
exc_info=exc,
)
return []
def make_frame_provider(
@@ -353,6 +367,91 @@ def make_frame_provider(
return provider
def _decode_frames_ffmpeg(video_path: Path, timestamps: list[float]) -> list[Any]:
"""Decode the frames nearest to ``timestamps`` via the ffmpeg CLI.
Runs one ``ffmpeg`` process per timestamp, seeking with ``-ss`` and
piping a single PNG to stdout. Unlike the in-process decoders this
survives a hostile container: a full ffmpeg build decodes AV1 (the codec
modern LeRobot datasets use) where torchcodec raises and PyAV can
SIGSEGV, and a crash stays isolated to the child process — a non-zero
exit is a catchable error, not a segfault of the whole job. Returns one
``(C, H, W)`` uint8 tensor per timestamp.
"""
import io # noqa: PLC0415
import subprocess # noqa: PLC0415
import numpy as np # noqa: PLC0415
frames: list[Any] = []
for ts in timestamps:
# ffmpeg invoked by name via PATH lookup; fully-controlled arg list, no shell.
proc = subprocess.run( # nosec B607
[
"ffmpeg",
"-nostdin",
"-loglevel",
"error",
"-ss",
f"{max(ts, 0.0):.3f}",
"-i",
str(video_path),
"-frames:v",
"1",
"-f",
"image2pipe",
"-vcodec",
"png",
"pipe:1",
],
capture_output=True,
check=True,
timeout=120,
)
if not proc.stdout:
raise RuntimeError(f"ffmpeg returned no frame for t={ts:.3f}s of {video_path}")
img = PIL.Image.open(io.BytesIO(proc.stdout)).convert("RGB")
frames.append(torch.from_numpy(np.asarray(img).copy()).permute(2, 0, 1).contiguous())
return frames
def _decode_frames_av(video_path: Path, timestamps: list[float]) -> list[Any]:
"""Decode the frames nearest to ``timestamps`` using PyAV directly.
lerobot's ``decode_video_frames(backend="pyav")`` routes through
``torchvision.io.VideoReader``, removed in torchvision 0.23+. This helper
talks to the ``av`` package directly. Note PyAV can SIGSEGV on AV1
streams in some builds — prefer ``_decode_frames_ffmpeg`` as the default
fallback; this stays available behind ``video_backend="pyav"``. Returns
one ``(C, H, W)`` uint8 tensor per timestamp.
"""
import av # noqa: PLC0415
first_ts = min(timestamps)
last_ts = max(timestamps)
loaded_frames: list[torch.Tensor] = []
loaded_ts: list[float] = []
with av.open(str(video_path)) as container:
stream = container.streams.video[0]
# Seek to the keyframe at or before the first requested timestamp.
offset = max(int(first_ts / stream.time_base), 0) if stream.time_base else 0
container.seek(offset, stream=stream, backward=True, any_frame=False)
for idx, frame in enumerate(container.decode(stream)):
ts = frame.time
if ts is None:
ts = float(frame.pts * stream.time_base) if frame.pts is not None else float(idx)
loaded_ts.append(ts)
loaded_frames.append(
torch.from_numpy(frame.to_ndarray(format="rgb24")).permute(2, 0, 1).contiguous()
)
if ts >= last_ts:
break
if not loaded_frames:
raise RuntimeError(f"PyAV decoded no frames from {video_path}")
ts_tensor = torch.tensor(loaded_ts)
return [loaded_frames[int(torch.argmin((ts_tensor - q).abs()))] for q in timestamps]
def _frame_to_pil(frame: Any) -> Any:
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
@@ -397,85 +496,3 @@ def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]
if not url:
return []
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
def _draw_timestamp_badge(image: PIL.Image.Image, timestamp: float) -> PIL.Image.Image:
"""Burn ``timestamp`` (seconds) into the top-left corner of ``image``.
A solid black badge with white text, so a VLM reading a contact sheet can
cite the exact source time of each tile (e.g. ``012.50s``) directly,
instead of the caller having to map tile position back to time. Mirrors
the macrodata/refiner contact-sheet convention.
"""
from PIL import ImageDraw, ImageFont
result = image.copy()
draw = ImageDraw.Draw(result)
font = ImageFont.load_default()
label = f"{timestamp:06.2f}s"
left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
text_w, text_h = right - left, bottom - top
pad = max(3, round(min(image.width, image.height) * 0.018))
draw.rectangle((0, 0, text_w + pad * 2, text_h + pad * 2), fill=(0, 0, 0))
draw.text((pad - left, pad - top), label, fill=(255, 255, 255), font=font)
return result
def to_contact_sheet_blocks(
frames: Sequence[Any],
timestamps: Sequence[float],
*,
columns: int = 5,
frames_per_sheet: int = 20,
frame_width: int = 224,
quality: int = 84,
) -> list[dict[str, Any]]:
"""Pack decoded frames into timestamped JPEG contact-sheet image blocks.
Each frame is resized to ``frame_width`` wide, stamped with its
episode-relative timestamp, and tiled row-major into grids of
``frames_per_sheet`` (``columns`` wide). One ``{"type":"image", ...}``
block is returned per grid; many frames collapse into a few images, so a
long episode's temporal coverage stays dense at a fraction of the vision
tokens N separate frames would cost. ``frames`` and ``timestamps`` must be
aligned and equal length. Returns ``[]`` for empty input.
"""
from PIL import Image
if not frames:
return []
columns = max(1, columns)
frames_per_sheet = max(1, frames_per_sheet)
rows_per_sheet = math.ceil(frames_per_sheet / columns)
tiles: list[PIL.Image.Image] = []
for ts, frame in zip(timestamps, frames, strict=False):
img = _frame_to_pil(frame)
if not isinstance(img, PIL.Image.Image):
continue
img = img.convert("RGB")
if img.width != frame_width:
height = max(1, round(img.height * frame_width / img.width))
img = img.resize((frame_width, height), resample=Image.Resampling.BILINEAR)
tiles.append(_draw_timestamp_badge(img, float(ts)))
if not tiles:
return []
blocks: list[dict[str, Any]] = []
for start in range(0, len(tiles), frames_per_sheet):
chunk = tiles[start : start + frames_per_sheet]
cell_w = max(tile.width for tile in chunk)
cell_h = max(tile.height for tile in chunk)
sheet = Image.new("RGB", (cell_w * columns, cell_h * rows_per_sheet), color=(0, 0, 0))
for i, tile in enumerate(chunk):
x = (i % columns) * cell_w
y = (i // columns) * cell_h
sheet.paste(tile, (x, y))
# JPEG round-trip at ``quality`` to match the refiner convention and
# shrink the wire payload; vision-token count is set by resolution, so
# the real saving is the grid packing, not the codec.
buf = io.BytesIO()
sheet.save(buf, format="JPEG", quality=quality)
buf.seek(0)
blocks.append({"type": "image", "image": Image.open(buf).convert("RGB")})
return blocks
@@ -20,13 +20,16 @@ from __future__ import annotations
import logging
from collections.abc import Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from ..config import PlanConfig
from ..frames import (
FrameProvider,
VideoFrameProvider,
null_provider,
to_contact_sheet_blocks,
to_video_block,
to_video_url_block,
)
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
@@ -36,44 +39,6 @@ from ..vlm_client import VlmClient
logger = logging.getLogger(__name__)
# Prepended to every describe / segment prompt so the VLM knows the images are
# timestamped contact-sheet grids, not a single video, and reads the burned-in
# per-tile timestamp when choosing boundaries.
def _contact_sheet_preamble(columns: int) -> str:
return (
"CONTACT SHEETS — how to read the images below:\n"
f"- Each image is a grid of sampled video frames, {columns} per row, "
"with time running left-to-right then top-to-bottom (row-major).\n"
"- Each frame has its timestamp burned into the top-left corner, e.g. "
'"012.50s". Use that printed timestamp (not the tile position) when you '
"choose start/end times; boundaries should land on or near a printed "
"timestamp.\n"
"- Frames continue across grids: an action may span the end of one sheet "
"and the start of the next, so do not place a boundary just because a new "
"image begins.\n\n"
)
# Appended to every describe (and segment) prompt. A visual, causal definition
# of where one event ends and the next begins — adapted from macrodata/refiner —
# to sharpen cut points while the existing prompt keeps owning the imperative
# phrasing.
_CAUSAL_BOUNDARY_RULES = (
"EVENT BOUNDARIES — where one event ends and the next begins:\n"
"- Start a new event whenever the world state changes: an object becomes "
"held (the gripper closes on it), an object is released (the gripper opens "
"and it stays put), an object reaches a new location, a lid/door/drawer "
"changes open/closed state, a tool starts or stops affecting a surface, or "
"contents visibly move (e.g. poured).\n"
"- If a single action changes the same state gradually and continuously, "
"keep it as ONE event — do not split it.\n"
"- If the same action repeats on different objects or target locations, "
"treat each repetition as a separate event.\n"
"- Do NOT create boundaries for idle time, camera motion, hesitation, or "
"tiny hand adjustments."
)
@dataclass
class PlanSubtasksMemoryModule:
"""Generate subtask spans, plan, and memory rows.
@@ -148,11 +113,9 @@ class PlanSubtasksMemoryModule:
"tool_calls": None,
}
)
# memory rows at every subtask boundary except the very first start;
# skipped entirely when ``emit_memory`` is False (subtasks-only / plan-only).
# memory rows at every subtask boundary except the very first start
prior_memory = ""
memory_boundaries = enumerate(subtask_spans[1:], start=1) if self.config.emit_memory else []
for i, span in memory_boundaries:
for i, span in enumerate(subtask_spans[1:], start=1):
completed = subtask_spans[i - 1]["text"]
remaining = [s["text"] for s in subtask_spans[i:]]
mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
@@ -257,13 +220,7 @@ class PlanSubtasksMemoryModule:
prompt: str,
window: tuple[float, float] | None = None,
) -> list[dict[str, Any]]:
"""User message combining the (optionally windowed) contact sheets with ``prompt``.
The prompt is always prefixed with a short explanation of how to read
the timestamped grids, so the model treats them as one ordered
sequence of frames rather than unrelated images.
"""
prompt = _contact_sheet_preamble(self.config.contact_sheet_columns) + prompt
"""User message combining the (optionally windowed) video block with ``prompt``."""
content = [*self._episode_video_block(record, window=window), {"type": "text", "text": prompt}]
return [{"role": "user", "content": content}]
@@ -336,19 +293,24 @@ class PlanSubtasksMemoryModule:
def _episode_video_block(
self, record: EpisodeRecord, window: tuple[float, float] | None = None
) -> list[dict[str, Any]]:
"""Timestamped contact sheets for the describe / segmentation prompts.
"""Video block for the segmentation / describe prompts.
Always renders the (optionally windowed) episode as contact sheets:
frames sampled at ``frames_per_second`` and packed into timestamped
JPEG grids. ``max_frames_per_prompt`` caps the frame count; whole
episodes that exceed it are windowed upstream in
:meth:`_generate_subtasks` so each call stays within budget while the
full episode keeps its sampling density.
Always returns a block that actually carries the video. When
``use_video_url`` is set we try the server-side ``video_url``
path first, but if clip extraction fails we FALL BACK to
decoding + embedding frames rather than returning an empty
block — an empty block would leave the VLM with no visual
grounding at all and it would hallucinate subtasks purely from
the task text.
When ``window=(w0, w1)`` is given the badges are WINDOW-RELATIVE
(``ts - w0``) to match the window-relative time frame the
segmentation prompt works in (spans are offset back to absolute time
afterwards).
When ``window=(w0, w1)`` is given (windowed subtask generation,
``subtask_window_seconds > 0``), embed frames sampled at the FIXED
``frames_per_second`` rate within ``[w0, w1]`` — constant temporal
density regardless of episode length, so long episodes are split
into windows rather than subsampled to a sparse 32-frame whole-
episode view. The ``video_url`` path is skipped for windows (it is
a whole-episode clip). ``max_video_frames`` still caps each window
as a context-budget safety net.
"""
if not record.frame_timestamps:
return []
@@ -356,44 +318,28 @@ class PlanSubtasksMemoryModule:
w0, w1 = float(window[0]), float(window[1])
dur = max(0.0, w1 - w0)
n = max(1, int(round(dur * self.config.frames_per_second)) + 1)
n = min(n, self.config.max_frames_per_prompt)
n = min(n, self.config.max_video_frames)
if n <= 1 or dur <= 0.0:
timestamps = [0.5 * (w0 + w1)]
else:
step = dur / (n - 1)
timestamps = [w0 + i * step for i in range(n)]
frames = self.frame_provider.frames_at(record, timestamps)
rel = [ts - w0 for ts in timestamps[: len(frames)]]
return self._contact_sheet_blocks(frames, rel)
return to_video_block(self.frame_provider.frames_at(record, timestamps))
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
clip = self.frame_provider.episode_clip_path(record, cache_dir)
if clip is not None:
return to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
logger.warning(
"episode %d: video_url clip extraction failed — falling back to "
"embedded frames so the VLM still sees the demonstration",
record.episode_index,
)
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
n = max(1, int(round(episode_duration * self.config.frames_per_second)) + 1)
n = min(n, self.config.max_frames_per_prompt)
timestamps = self._uniform_episode_timestamps(record, n)
frames = self.frame_provider.frames_at(record, timestamps)
return self._contact_sheet_blocks(frames, timestamps[: len(frames)])
@staticmethod
def _uniform_episode_timestamps(record: EpisodeRecord, n: int) -> list[float]:
"""``n`` episode-relative timestamps spanning ``[t0, t_last]`` uniformly."""
ts = record.frame_timestamps
if n >= len(ts):
return [float(t) for t in ts]
t0, t_last = float(ts[0]), float(ts[-1])
if t_last <= t0 or n <= 1:
return [t0] * max(1, n)
step = (t_last - t0) / (n - 1)
return [t0 + i * step for i in range(n)]
def _contact_sheet_blocks(self, frames: list[Any], timestamps: list[float]) -> list[dict[str, Any]]:
"""Build timestamped contact-sheet image blocks from decoded frames."""
return to_contact_sheet_blocks(
frames,
timestamps,
columns=self.config.contact_sheet_columns,
frames_per_sheet=self.config.contact_sheet_frames_per_sheet,
frame_width=self.config.contact_sheet_frame_width,
quality=self.config.contact_sheet_quality,
)
target_count = max(1, int(round(episode_duration * self.config.frames_per_second)))
target_count = min(target_count, self.config.max_video_frames)
video_frames = self.frame_provider.video_for_episode(record, target_count)
return to_video_block(video_frames)
def run_plan_updates(
self,
@@ -459,17 +405,12 @@ class PlanSubtasksMemoryModule:
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
effective_task = task if task is not None else record.episode_task
# ---- Auto-windowing (keeps the full sampling density) --------
# Contact sheets are cheap, but a whole long episode sampled at
# ``frames_per_second`` can still exceed ``max_frames_per_prompt``.
# When it does, split into consecutive windows of exactly that many
# frames (one describe→segment call each, still at the full sampling
# density), then merge + stitch — so an episode of any length is
# covered at full density rather than subsampled into one sparse call.
fps = max(1e-6, float(self.config.frames_per_second))
n_whole = int(round(episode_duration * fps)) + 1
if n_whole > self.config.max_frames_per_prompt:
window_s = self.config.max_frames_per_prompt / fps
# ---- Windowed path (constant temporal density) ---------------
# If subtask_window_seconds > 0 and the episode exceeds one window,
# process fixed-length windows so the VLM always sees
# frames_per_second density; results are merged + stitched.
window_s = float(getattr(self.config, "subtask_window_seconds", 0.0) or 0.0)
if window_s > 0.0 and episode_duration > window_s:
return self._generate_subtasks_windowed(record, effective_task, window_s)
# ---- Pass 1 (optional): grounding description ----------------
@@ -487,14 +428,12 @@ class PlanSubtasksMemoryModule:
)
# ---- Pass 2: segmentation ------------------------------------
prompt = self._with_causal_rules(
load_prompt("plan_subtasks").format(
episode_task=effective_task,
min_subtask_seconds=self.config.min_subtask_seconds,
max_steps=self.config.plan_max_steps,
episode_duration=f"{episode_duration:.3f}",
observation_block=observation_block,
)
prompt = load_prompt("plan_subtasks").format(
episode_task=effective_task,
min_subtask_seconds=self.config.min_subtask_seconds,
max_steps=self.config.plan_max_steps,
episode_duration=f"{episode_duration:.3f}",
observation_block=observation_block,
)
spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
cleaned = self._clean_spans(spans, record)
@@ -569,14 +508,12 @@ class PlanSubtasksMemoryModule:
"action that is not in your description above.\n\n"
)
prompt = self._with_causal_rules(
load_prompt("plan_subtasks").format(
episode_task=task,
min_subtask_seconds=self.config.min_subtask_seconds,
max_steps=self.config.plan_max_steps,
episode_duration=f"{win_len:.3f}",
observation_block=observation_block,
)
prompt = load_prompt("plan_subtasks").format(
episode_task=task,
min_subtask_seconds=self.config.min_subtask_seconds,
max_steps=self.config.plan_max_steps,
episode_duration=f"{win_len:.3f}",
observation_block=observation_block,
)
spans = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks")
# Window-relative clamp; no frame-snap dedupe yet (done on the
@@ -623,11 +560,6 @@ class PlanSubtasksMemoryModule:
s["end"] = float(s["start"])
return spans
@staticmethod
def _with_causal_rules(prompt: str) -> str:
"""Append the causal event-boundary rules to a describe/segment prompt."""
return f"{prompt}\n\n{_CAUSAL_BOUNDARY_RULES}"
def _clean_spans(
self,
spans: Any,
@@ -675,7 +607,7 @@ class PlanSubtasksMemoryModule:
self, record: EpisodeRecord, task: str, window: tuple[float, float] | None = None
) -> str:
"""Grounding pass: free-form chronological description of the (windowed) video."""
prompt = self._with_causal_rules(load_prompt("plan_subtask_describe").format(episode_task=task))
prompt = load_prompt("plan_subtask_describe").format(episode_task=task)
text = self._vlm_field(self._video_message(record, prompt, window=window), "description")
return text.strip() if isinstance(text, str) and text.strip() else ""
@@ -310,19 +310,6 @@ def _make_openai_client(config: VlmConfig) -> VlmClient:
return _GenericTextClient(_gen, config)
def _bind_serve_port(cmd: str, port: int) -> str:
"""Bind a serve command to ``port``: substitute a ``{port}`` placeholder
if present, else append ``--port`` when the command omits it (leaving an
explicit ``--port`` untouched). Shared by the single- and parallel-server
paths so a serve_command never reaches the server with a literal
``{port}``."""
if "{port}" in cmd:
return cmd.replace("{port}", str(port))
if "--port" not in cmd:
return f"{cmd} --port {port}"
return cmd
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
"""Spawn ``config.parallel_servers`` independent vllm replicas.
@@ -365,7 +352,7 @@ def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
gpu = i % num_gpus
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
cmd = _bind_serve_port(base_cmd, port)
cmd = base_cmd.replace("{port}", str(port)) if "{port}" in base_cmd else f"{base_cmd} --port {port}"
api_base = f"http://localhost:{port}/v1"
api_bases.append(api_base)
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
@@ -464,11 +451,6 @@ def _spawn_inference_server(config: VlmConfig) -> str:
f"transformers serve {shlex.quote(config.model_id)} "
f"--port {config.serve_port} --continuous-batching"
)
# Bind the single server to ``serve_port`` (what ``api_base`` below
# targets): substitute a literal ``{port}`` placeholder, else append
# ``--port``. Without this a serve_command carrying ``{port}`` would
# reach the server unsubstituted and fail to parse.
cmd = _bind_serve_port(cmd, config.serve_port)
api_base = f"http://localhost:{config.serve_port}/v1"
print(f"[server] launching: {cmd}", flush=True)
proc = subprocess.Popen(
+4 -37
View File
@@ -49,19 +49,8 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa
return output_dir / CHECKPOINTS_DIR / step_identifier
def save_training_step(
step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None
) -> None:
state: dict = {"step": step}
# num_processes and batch_size are recorded so a resumed run can detect a changed world size or
# batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that
# produced `step`, since both scale how many sampler positions a step consumes (see
# compute_sampler_state).
if num_processes is not None:
state["num_processes"] = num_processes
if batch_size is not None:
state["batch_size"] = batch_size
write_json(state, save_dir / TRAINING_STEP)
def save_training_step(step: int, save_dir: Path) -> None:
write_json({"step": step}, save_dir / TRAINING_STEP)
def load_training_step(save_dir: Path) -> int:
@@ -69,16 +58,6 @@ def load_training_step(save_dir: Path) -> int:
return training_step["step"]
def load_training_num_processes(checkpoint_dir: Path) -> int | None:
"""World size recorded at checkpoint time, or None for checkpoints written before it was stored."""
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes")
def load_training_batch_size(checkpoint_dir: Path) -> int | None:
"""Per-process batch size recorded at checkpoint time, or None for older checkpoints."""
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size")
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
if last_checkpoint_dir.is_symlink():
@@ -96,8 +75,6 @@ def save_checkpoint(
scheduler: LRScheduler | None = None,
preprocessor: PolicyProcessorPipeline | None = None,
postprocessor: PolicyProcessorPipeline | None = None,
num_processes: int | None = None,
batch_size: int | None = None,
) -> None:
"""This function creates the following directory structure:
@@ -123,10 +100,6 @@ def save_checkpoint(
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
preprocessor: The preprocessor/pipeline to save. Defaults to None.
postprocessor: The postprocessor/pipeline to save. Defaults to None.
num_processes (int | None, optional): Distributed world size to record for sample-exact
resume. Defaults to None (not recorded).
batch_size (int | None, optional): Per-process batch size to record for sample-exact
resume. Defaults to None (not recorded).
"""
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)
@@ -139,9 +112,7 @@ def save_checkpoint(
preprocessor.save_pretrained(pretrained_dir)
if postprocessor is not None:
postprocessor.save_pretrained(pretrained_dir)
save_training_state(
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
)
save_training_state(checkpoint_dir, step, optimizer, scheduler)
def save_training_state(
@@ -149,8 +120,6 @@ def save_training_state(
train_step: int,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
num_processes: int | None = None,
batch_size: int | None = None,
) -> None:
"""
Saves the training step, optimizer state, scheduler state, and rng state.
@@ -162,12 +131,10 @@ def save_training_state(
Defaults to None.
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
Defaults to None.
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
"""
save_dir = checkpoint_dir / TRAINING_STATE_DIR
save_dir.mkdir(parents=True, exist_ok=True)
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
save_training_step(train_step, save_dir)
save_rng_state(save_dir)
if optimizer is not None:
save_optimizer_state(optimizer, save_dir)
+146
View File
@@ -205,3 +205,149 @@ class WandBLogger:
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
def log_training_examples(
self,
batch: dict,
step: int,
*,
camera_keys: list[str],
n_samples: int = 4,
policy=None,
predict_actions: bool = False,
mode: str = "train",
) -> None:
"""Push a ``wandb.Table`` of training-example rows for the current batch.
Each row is one batch element with:
* one ``wandb.Image`` column per camera in ``camera_keys`` (CHW or
HWC, uint8 or float in [0,1] — auto-detected),
* any text fields present in the batch (``task`` / ``subtask`` /
``memory`` / ``instruction``),
* ground-truth action first/last frame (the action chunk's
endpoints — gives a quick sense of trajectory direction),
* if ``predict_actions=True`` and ``policy`` is supplied, the model's
``predict_action_chunk`` first/last frame alongside.
This is opt-in via ``--wandb.log_examples_freq=N`` on the CLI; the
training loop calls it once every N steps. Cheap to keep on: with
N=4 samples and 3 cameras you upload 12 small PNGs per dump and (if
enabled) run one extra inference forward pass.
"""
import logging # noqa: PLC0415
import numpy as np # noqa: PLC0415
import torch # noqa: PLC0415
if mode not in {"train", "eval"}:
raise ValueError(mode)
# Batch size — first tensor-like value wins.
bsz = next(
(int(v.shape[0]) for v in batch.values() if hasattr(v, "shape") and v.ndim > 0),
None,
)
if not bsz:
return
n = min(int(n_samples), bsz)
# Optional predicted-action forward pass on the first n samples.
pred_actions: np.ndarray | None = None
if predict_actions and policy is not None:
was_training = policy.training
try:
policy.eval()
sub_batch = {}
for k, v in batch.items():
if isinstance(v, torch.Tensor):
sub_batch[k] = v[:n]
elif isinstance(v, (list, tuple)):
sub_batch[k] = list(v[:n])
else:
sub_batch[k] = v
with torch.no_grad():
pred = policy.predict_action_chunk(sub_batch)
pred_actions = pred.detach().cpu().float().numpy()
except Exception as exc: # noqa: BLE001
logging.warning(
"log_training_examples: predict_action_chunk failed (%s) — "
"skipping predicted-action columns",
exc,
)
pred_actions = None
finally:
if was_training:
policy.train()
present_cameras = [c for c in camera_keys if c in batch]
text_keys = [k for k in ("task", "subtask", "memory", "instruction") if k in batch]
columns = ["sample"]
columns.extend(c.removeprefix("observation.images.") or c for c in present_cameras)
columns.extend(text_keys)
columns.append("gt_action_first")
columns.append("gt_action_last")
if pred_actions is not None:
columns.append("pred_action_first")
columns.append("pred_action_last")
table = self._wandb.Table(columns=columns)
def _to_uint8_hwc(t: torch.Tensor) -> np.ndarray:
# Strip an outer time dim if present: (T, C, H, W) -> first frame.
if t.ndim == 4:
t = t[0]
# CHW -> HWC.
if t.ndim == 3 and t.shape[0] in (1, 3, 4) and t.shape[-1] not in (1, 3, 4):
t = t.permute(1, 2, 0)
arr = t.detach().cpu().float().numpy()
if arr.size and float(arr.max()) <= 1.5:
arr = arr * 255.0
return np.clip(arr, 0, 255).astype(np.uint8)
def _action_endpoints(a: torch.Tensor) -> tuple[str, str]:
arr = a.detach().cpu().float().numpy()
if arr.ndim == 2: # (T, D)
return (
str(np.round(arr[0], 3).tolist()),
str(np.round(arr[-1], 3).tolist()),
)
if arr.ndim == 1:
rounded = np.round(arr, 3).tolist()
return (str(rounded), str(rounded))
return (str(arr.tolist()), str(arr.tolist()))
for i in range(n):
row: list = [i]
for cam in present_cameras:
try:
row.append(self._wandb.Image(_to_uint8_hwc(batch[cam][i])))
except Exception as exc: # noqa: BLE001
logging.warning(
"log_training_examples: camera %s sample %d failed (%s)",
cam,
i,
exc,
)
row.append(None)
for tk in text_keys:
v = batch[tk]
if isinstance(v, (list, tuple)):
row.append(str(v[i]) if i < len(v) else "")
else:
row.append(str(v))
action = batch.get("action")
if isinstance(action, torch.Tensor) and action.ndim >= 1:
first, last = _action_endpoints(action[i])
row.append(first)
row.append(last)
else:
row.append("")
row.append("")
if pred_actions is not None:
p = torch.from_numpy(pred_actions[i])
pfirst, plast = _action_endpoints(p)
row.append(pfirst)
row.append(plast)
table.add_data(*row)
self._wandb.log({f"{mode}/examples": table}, step=step)
+66 -2
View File
@@ -39,8 +39,6 @@ class DatasetConfig:
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False
streaming: bool = False
# Fraction of episodes held out per task for offline evaluation (0.0 = disabled).
eval_split: float = 0.0
def __post_init__(self) -> None:
if self.episodes is not None:
@@ -64,6 +62,72 @@ class WandBConfig:
run_id: str | None = None
mode: str | None = None # Allowed values: 'online', 'offline' 'disabled'. Defaults to 'online'
add_tags: bool = True # If True, save configuration as tags in the WandB run.
# Periodic training-example dump (independent of ``log_freq``). When > 0,
# every ``log_examples_freq`` steps the trainer pushes a ``wandb.Table``
# with one row per sampled batch element containing each camera view
# (rendered as ``wandb.Image``), any text fields present in the batch
# (``task`` / ``subtask`` / ``memory`` / ``instruction``), and the
# ground-truth action chunk's first + last frames. Defaults to 5000 — set
# to 0 to disable. Only fires when ``enable=True``, so runs without wandb
# are unaffected.
log_examples_freq: int = 5000
# Number of batch elements to include in each example dump.
log_examples_n: int = 4
# If True (default), also run ``policy.predict_action_chunk`` on the logged
# samples (in eval mode, no_grad) and add predicted vs ground-truth action
# columns to the table. Costs one extra forward pass per dump — negligible
# at the 5k-step default cadence. Set to ``False`` if your policy doesn't
# implement ``predict_action_chunk`` or you want to skip the extra forward.
log_examples_predict_actions: bool = True
@dataclass
class EMAConfig:
"""Exponential Moving Average of trainable policy parameters.
Diffusion / flow-matching policies (Diffusion Policy, π0/π0.5,
pi052) benefit substantially from averaging late-training
parameter oscillations — see Chi et al. 2023 §V.D. The official
JAX openpi trainer ships EMA with ``ema_decay=0.99`` (default) and
``0.999`` for its pi05_libero config; the openpi PyTorch port
explicitly lists EMA as unsupported, and LeRobot main inherited
that gap. Enabling this flag plugs ema-pytorch
(https://github.com/lucidrains/ema-pytorch) into the LeRobot
training loop with a shadow ``nn.Module`` clone of the policy.
Cost: 1× model params in fp32 shadow (~13 GB for pi052's 3.3B
params) + one elementwise update per training step (~1% step time).
Off by default (opt-in): EMA is only beneficial for flow-matching /
diffusion policies (pi0/pi05/pi052), and the fp32 shadow copy is pure
overhead for other policies (e.g. VLA-JEPA). Set ``--ema.enable=true``
to turn it on (the pi05/pi052 training recipes do this). openpi (JAX)
ships EMA on for every config; enable it explicitly to match that.
"""
enable: bool = False
# Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live (passed to
# ema-pytorch as ``beta``).
# 0.999 — last ~1000 steps; pi05_libero default in openpi
# 0.99 — last ~100 steps; openpi top-level default
# 0.75 — very fast EMA (Diffusion Policy original setting)
# 0.9999 — very slow EMA (long classification runs)
decay: float = 0.99
# Skip the first N calls to ``ema.update()``; during this window
# the shadow is just a hard copy of the live weights (no averaging).
# Lets early-training rapid changes settle before averaging begins.
# Maps to ema-pytorch's ``update_after_step`` (NOT a smooth decay
# ramp like older lerobot EMA implementations).
warmup_steps: int = 0
# When True, the periodic eval block uses the EMA shadow model
# directly (``ema.ema_model``) instead of the live policy. Standard
# practice for diffusion-style policies — eval scores are usually
# 13% higher than the live policy at the same step.
use_for_eval: bool = True
# When True, the periodic wandb training-example dump uses the EMA
# shadow for the optional predicted-action columns (so what you see
# in W&B matches eval behavior).
use_for_wandb_examples: bool = True
@dataclass
+18 -3
View File
@@ -147,7 +147,16 @@ class TrainingRecipe:
return cls.from_dict(data)
def _validate_message_recipe(self) -> None:
"""Ensure every templated binding is known and at least one turn is a target."""
"""Ensure every templated binding is known and the recipe supervises something.
A recipe is valid if it has at least one of:
* a ``target: true`` assistant turn (drives text-CE supervision), or
* a ``stream: low_level`` turn (drives flow / action supervision via
``predict_actions=True``, even when no assistant turn is targeted —
e.g. π0.5-style ``low_level_execution`` where the action expert
conditions on a user-only ``${subtask}`` prompt).
"""
assert self.messages is not None
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
@@ -156,8 +165,14 @@ class TrainingRecipe:
if missing:
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
if not any(turn.target for turn in self.messages):
raise ValueError("Message recipes must contain at least one target turn.")
has_target = any(turn.target for turn in self.messages)
has_low_level = any(turn.stream == "low_level" for turn in self.messages)
if not (has_target or has_low_level):
raise ValueError(
"Message recipes must contain at least one supervised turn — "
"either ``target: true`` (text CE) or ``stream: low_level`` "
"(flow/action loss)."
)
def _validate_blend_recipe(self) -> None:
"""Ensure each blend component is a non-empty, weighted message recipe."""
@@ -0,0 +1,68 @@
# subtask_mem_vqa_speech — Hi-Robot blend + memory + spoken responses.
#
# Superset of subtasks_vqa.yaml. Keeps the core subtask + action + VQA
# training, and adds two text-supervised tasks:
#
# high_level_subtask — predict the subtask from the task.
# low_level_execution — flow loss with [images, subtask, state].
# memory_update — compress progress into a memory note.
# user_interjection_response — reply to a user interjection with a
# spoken `say` tool call (no plan, no
# subtask text — just the spoken reply).
# ask_vqa_{top,wrist} — camera-grounded VQA.
#
# Plan is intentionally left out — memory is the only persistent
# high-level state here, keeping the prompt short.
#
# Requires the dataset to carry `memory`, `interjection` and `say`-tool
# annotations (the annotation pipeline's memory + interjection modules)
# in addition to `subtask` and `vqa`. Sub-recipes whose `if_present`
# bindings are missing simply don't render for that sample, so a
# dataset without interjections still trains the rest of the blend.
#
# Tool-call note: the `say` tool call on the interjection-response turn
# is flattened to a `<say>...</say>` text marker by the tokenizer step
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
# marker the runtime parses back (`_split_plan_and_say`).
blend:
high_level_subtask:
weight: 0.30
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.55
messages:
# The action expert is conditioned on the SUBTASK — at inference
# `HighLevelSubtaskFwd` generates it via the LM head and feeds it
# here. `stream: low_level` flips `predict_actions=True` so the
# flow loss fires; no text-CE target (subtask prediction is owned
# by `high_level_subtask`).
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
memory_update:
# At inference, `MemoryUpdateFwd` is triggered only on
# `subtask_change` events (sparse). Training densely with
# `active_at` — i.e. on every frame inside a subtask interval,
# not just the boundary frame — supervises the same
# (prior_memory, completed_subtask) → current_memory mapping
# against varied observations within the interval. The model
# learns a stateless transformation; the *when* to emit lives in
# the inference trigger, not the model. Annotations only exist
# for ~1% of frames as boundary events, so `emitted_at` would
# waste 99% of the blend draws (and silently leak them into a
# task-conditioned fallback); `active_at` lifts the renderable
# rate to ~87% on this dataset.
weight: 0.15
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "active_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
@@ -0,0 +1,99 @@
# subtask_mem_vqa_robocasa — Hi-Robot blend tuned for RoboCasa cameras.
#
# Same supervision as ``subtask_mem.yaml`` (subtask + memory) plus
# camera-grounded VQA across the three RoboCasa camera keys produced
# by ``slurm_build_robocasa_composite_seen.py``:
#
# observation.images.robot0_agentview_left (left scene view)
# observation.images.robot0_agentview_right (right scene view)
# observation.images.robot0_eye_in_hand (wrist)
#
# The annotation pipeline (``examples/annotations/run_hf_job.py``) emits
# VQA per camera, so each anchor frame produces three (user, assistant)
# rows tagged with their source camera. Each VQA sub-recipe consumes
# the rows for one camera via ``camera=...`` resolver bindings.
#
# Spatial VQA targets (bbox / point) are rewritten from JSON to
# PaliGemma ``<locDDDD>`` tokens by ``_messages_vqa_to_loc`` —
# ``register_paligemma_loc_tokens`` already collapses them to single
# detection-vocab ids so the LM head learns the pretrained pointing /
# detection prior, not a 7-piece BPE salad.
#
# Interjections / spoken responses are intentionally absent — the
# annotation job runs with ``--interjections.enabled=false``.
blend:
high_level_subtask:
weight: 0.25
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.45
messages:
# Action expert is conditioned on the SUBTASK; at inference the
# high-level loop generates it via the LM head and feeds it here.
# ``stream: low_level`` flips ``predict_actions=True`` so the flow
# loss fires; subtask CE is owned by ``high_level_subtask``.
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
memory_update:
# Trained densely with ``active_at`` — every frame inside a subtask
# interval — so the (prior_memory, completed_subtask) → current_memory
# mapping is supervised against varied observations. The *when* to
# emit lives in the inference trigger (subtask_change), not the
# model. See ``subtask_mem.yaml`` for the long version of this note.
weight: 0.15
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "active_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
ask_vqa_agentview_left:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_left)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_left)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_agentview_left}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_agentview_right:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_right)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_right)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_agentview_right}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_wrist:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_eye_in_hand)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_eye_in_hand)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_eye_in_hand}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
@@ -0,0 +1,114 @@
# subtask_mem_vqa_speech — Hi-Robot blend + memory + spoken responses.
#
# Superset of subtasks_vqa.yaml. Keeps the core subtask + action + VQA
# training, and adds two text-supervised tasks:
#
# high_level_subtask — predict the subtask from the task.
# low_level_execution — flow loss with [images, subtask, state].
# memory_update — compress progress into a memory note.
# user_interjection_response — reply to a user interjection with a
# spoken `say` tool call (no plan, no
# subtask text — just the spoken reply).
# ask_vqa_{top,wrist} — camera-grounded VQA.
#
# Plan is intentionally left out — memory is the only persistent
# high-level state here, keeping the prompt short.
#
# Requires the dataset to carry `memory`, `interjection` and `say`-tool
# annotations (the annotation pipeline's memory + interjection modules)
# in addition to `subtask` and `vqa`. Sub-recipes whose `if_present`
# bindings are missing simply don't render for that sample, so a
# dataset without interjections still trains the rest of the blend.
#
# Tool-call note: the `say` tool call on the interjection-response turn
# is flattened to a `<say>...</say>` text marker by the tokenizer step
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
# marker the runtime parses back (`_split_plan_and_say`).
blend:
high_level_subtask:
weight: 0.25
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.40
messages:
# The action expert is conditioned on the SUBTASK — at inference
# `HighLevelSubtaskFwd` generates it via the LM head and feeds it
# here. `stream: low_level` flips `predict_actions=True` so the
# flow loss fires; no text-CE target (subtask prediction is owned
# by `high_level_subtask`).
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
memory_update:
# At inference, `MemoryUpdateFwd` is triggered only on
# `subtask_change` events (sparse). Training densely with
# `active_at` — i.e. on every frame inside a subtask interval,
# not just the boundary frame — supervises the same
# (prior_memory, completed_subtask) → current_memory mapping
# against varied observations within the interval. The model
# learns a stateless transformation; the *when* to emit lives in
# the inference trigger, not the model. Annotations only exist
# for ~1% of frames as boundary events, so `emitted_at` would
# waste 99% of the blend draws (and silently leak them into the
# task-conditioned fallback); `active_at` lifts the renderable
# rate to ~87% on Hi-Robot-style datasets.
weight: 0.10
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "active_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
user_interjection_response:
weight: 0.10
bindings:
interjection: "emitted_at(t, style=interjection)"
speech: "emitted_at(t, role=assistant, tool_name=say)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
# Spoken reply only: the assistant turn carries no text content,
# just a `say` tool call (`tool_calls_from: speech`). The chat
# tokenizer flattens it to a `<say>...</say>` marker, so the
# supervised target trains the model to respond to an
# interjection with a spoken acknowledgement.
- {role: assistant, stream: high_level, target: true, if_present: speech, tool_calls_from: speech}
# VQA is view-dependent — each camera gets its own sub-recipe so the
# resolver disambiguates via `camera=...`. Camera keys match
# subtasks_vqa.yaml (`front` + `wrist`); adjust to your dataset.
ask_vqa_top:
weight: 0.075
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.front}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_wrist:
weight: 0.075
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.wrist}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
@@ -0,0 +1,61 @@
# subtasks_vqa — Hi-Robot blend for PI052 (PaliGemma backbone).
#
# Trains two things only: subtasks and VQA. Plan and memory are
# intentionally left out — keeps the prompt short and the training
# surface small. The fuller blend with memory + spoken replies is
# ``subtask_mem_vqa_speech.yaml``.
#
# high_level_subtask — predict the subtask from the task.
# low_level_execution — flow loss with [images, subtask, state].
# ask_vqa_{top,wrist} — camera-grounded VQA.
#
# PI052's text tokenizer renders these messages as plain
# ``Role: content`` text (PaliGemma is not chat-pretrained).
blend:
high_level_subtask:
weight: 0.40
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.40
messages:
# The action expert is conditioned on the SUBTASK — at inference
# the high-level loop (``HighLevelSubtaskFwd``) generates the
# subtask via the LM head and feeds it here. The action expert's
# prefix is [images, subtask, state]. ``stream: low_level`` flips
# ``predict_actions=True`` so the flow loss fires; no text-CE
# target here (subtask prediction is owned by
# ``high_level_subtask``).
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
ask_vqa_top:
weight: 0.10
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.front}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_wrist:
weight: 0.10
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.wrist}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
+14 -8
View File
@@ -30,7 +30,7 @@ from lerobot.utils.hub import HubMixin
from lerobot.utils.sample_weighting import SampleWeightingConfig
from . import parser
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .default import DatasetConfig, EMAConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .rewards import RewardModelConfig
@@ -100,13 +100,8 @@ class TrainPipelineConfig(HubMixin):
prefetch_factor: int = 4
persistent_workers: bool = True
steps: int = 100_000
# Run policy in the simulation environment every N steps to measure reward/success (0 = disabled).
env_eval_freq: int = 20_000
eval_freq: int = 20_000
log_freq: int = 200
# Compute eval loss on held-out episodes every N steps (0 = disabled). Requires eval_split > 0.
eval_steps: int = 0
# Cap on total eval samples, split uniformly across tasks (0 = use all held-out data).
max_eval_samples: int = 0
tolerance_s: float = 1e-4
save_checkpoint: bool = True
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
@@ -116,9 +111,20 @@ class TrainPipelineConfig(HubMixin):
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
ema: EMAConfig = field(default_factory=EMAConfig)
peft: PeftConfig | None = None
# Sample weighting configuration (e.g., for RA-BC training)
# VQA oversampling. When set (a fraction in (0, 1)), the training
# dataloader uses a WeightedEpisodeAwareSampler that draws frames
# carrying a `vqa` language annotation often enough that they make
# up roughly this fraction of the training stream. VQA annotations
# are typically sparse, so without this they are underrepresented.
# `None` (default) keeps uniform episode-aware sampling.
vqa_target_fraction: float | None = None
# Sample weighting configuration (e.g., for RA-BC training). Old
# inline ``use_rabc`` / ``rabc_*`` params are migrated to this
# field by ``_migrate_legacy_rabc_keys`` above.
sample_weighting: SampleWeightingConfig | None = None
# Rename map for the observation to override the image and state keys
+15 -4
View File
@@ -35,7 +35,6 @@ from .dataset_tools import (
remove_feature,
split_dataset,
)
from .factory import make_dataset, make_train_eval_datasets, resolve_delta_timestamps
from .image_writer import safe_stop_image_writer
from .io_utils import load_episodes, write_stats
from .language import (
@@ -50,11 +49,24 @@ from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
from .sampler import EpisodeAwareSampler, compute_sampler_state
from .sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
from .video_utils import VideoEncodingManager
def make_dataset(*args, **kwargs):
from .factory import make_dataset as _make_dataset
return _make_dataset(*args, **kwargs)
def resolve_delta_timestamps(*args, **kwargs):
from .factory import resolve_delta_timestamps as _resolve_delta_timestamps
return _resolve_delta_timestamps(*args, **kwargs)
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
# and legacy migration constants are intentionally NOT re-exported here.
# Import directly: ``from lerobot.datasets.io_utils import ...``
@@ -65,6 +77,7 @@ __all__ = [
"DEFAULT_QUANTILES",
"EVENT_ONLY_STYLES",
"EpisodeAwareSampler",
"WeightedEpisodeAwareSampler",
"LANGUAGE_EVENTS",
"LANGUAGE_PERSISTENT",
"LeRobotDataset",
@@ -82,14 +95,12 @@ __all__ = [
"aggregate_stats",
"convert_image_to_video_dataset",
"create_initial_features",
"compute_sampler_state",
"create_lerobot_dataset_card",
"column_for_style",
"delete_episodes",
"get_feature_stats",
"load_episodes",
"make_dataset",
"make_train_eval_datasets",
"merge_datasets",
"modify_features",
"modify_tasks",
+6 -21
View File
@@ -286,8 +286,6 @@ def aggregate_datasets(
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
chunk_size: int | None = None,
concatenate_videos: bool = True,
concatenate_data: bool = True,
):
"""Aggregates multiple LeRobot datasets into a single unified dataset.
@@ -305,8 +303,6 @@ def aggregate_datasets(
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
"""
logging.info("Start aggregate_datasets")
@@ -355,12 +351,8 @@ def aggregate_datasets(
dst_meta.episodes = {}
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
videos_idx = aggregate_videos(
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos
)
data_idx = aggregate_data(
src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data
)
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
@@ -375,9 +367,7 @@ def aggregate_datasets(
logging.info("Aggregation complete.")
def aggregate_videos(
src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size, concatenate_videos=True
):
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
"""Aggregates video chunks from a source dataset into the destination dataset.
Handles video file concatenation and rotation based on file size limits.
@@ -389,7 +379,6 @@ def aggregate_videos(
videos_idx: Dictionary tracking video chunk and file indices.
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
Returns:
dict: Updated videos_idx with current chunk and file indices.
"""
@@ -450,7 +439,7 @@ def aggregate_videos(
src_size = get_file_size_in_mb(src_path)
dst_size = get_file_size_in_mb(dst_path)
if not concatenate_videos or dst_size + src_size >= video_files_size_in_mb:
if dst_size + src_size >= video_files_size_in_mb:
# Rotate to a new file - offset is 0
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
dst_key = (chunk_idx, file_idx)
@@ -488,7 +477,7 @@ def aggregate_videos(
return videos_idx
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size, concatenate_data=True):
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
"""Aggregates data chunks from a source dataset into the destination dataset.
Reads source data files, updates indices to match the aggregated dataset,
@@ -504,7 +493,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
data_idx: Dictionary tracking data chunk and file indices.
data_files_size_in_mb: Maximum size for data files in MB.
chunk_size: Maximum number of files per chunk.
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
Returns:
dict: Updated data_idx with current chunk and file indices.
@@ -550,7 +538,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
contains_images=contains_images,
aggr_root=dst_meta.root,
hf_features=hf_features,
concatenate=concatenate_data,
)
# Record the mapping from source to actual destination
@@ -627,7 +614,6 @@ def append_or_create_parquet_file(
contains_images: bool = False,
aggr_root: Path = None,
hf_features: datasets.Features | None = None,
concatenate: bool = True,
) -> tuple[dict[str, int], tuple[int, int]]:
"""Appends data to an existing parquet file or creates a new one based on size constraints.
@@ -644,7 +630,6 @@ def append_or_create_parquet_file(
contains_images: Whether the data contains images requiring special handling.
aggr_root: Root path for the aggregated dataset.
hf_features: Optional HuggingFace Features schema for proper image typing.
concatenate: When False, always rotate to a new file instead of appending to the current one.
Returns:
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
@@ -664,7 +649,7 @@ def append_or_create_parquet_file(
src_size = get_parquet_file_size_in_mb(src_path)
dst_size = get_parquet_file_size_in_mb(dst_path)
if not concatenate or dst_size + src_size >= max_mb:
if dst_size + src_size >= max_mb:
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
dst_chunk, dst_file = idx["chunk"], idx["file"]
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
-2
View File
@@ -59,8 +59,6 @@ class RunningQuantileStats:
batch: An array where all dimensions except the last are batch dimensions.
"""
batch = batch.reshape(-1, batch.shape[-1])
# Promote integer and low-precision inputs before computing squared statistics.
batch = batch.astype(np.result_type(batch.dtype, np.float32), copy=False)
num_elements, vector_length = batch.shape
if self._count == 0:
+43
View File
@@ -126,10 +126,53 @@ class DatasetReader:
def _load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self._meta.features)
# Datasets annotated with the PR1 language columns may have been
# written without registering those columns in ``meta/info.json``
# (e.g. they predate ``CODEBASE_VERSION="v3.1"`` and were
# back-filled by ``lerobot-annotate``). Probe a single parquet
# shard and graft the column features on so the strict
# ``Dataset.from_parquet`` cast doesn't fail with
# ``column names don't match``.
features = self._extend_features_with_language_columns(features)
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def _extend_features_with_language_columns(
self, features: datasets.Features
) -> datasets.Features:
"""Add ``language_persistent`` / ``language_events`` to ``features``
when the underlying parquet shards declare them but the metadata
doesn't. No-op when neither column is present or both are
already registered.
"""
# Find any one parquet to peek at; bail if there are none yet
# (the dataset will fail later for an unrelated reason and we
# want that error to surface as-is).
try:
sample = next((self.root / "data").glob("*/*.parquet"))
except StopIteration:
return features
from pyarrow import parquet as _pq # noqa: PLC0415
schema_names = set(_pq.read_schema(sample).names)
from .language import ( # noqa: PLC0415
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
language_events_column_feature,
language_persistent_column_feature,
)
extra: dict[str, object] = {}
if LANGUAGE_PERSISTENT in schema_names and LANGUAGE_PERSISTENT not in features:
extra[LANGUAGE_PERSISTENT] = language_persistent_column_feature()
if LANGUAGE_EVENTS in schema_names and LANGUAGE_EVENTS not in features:
extra[LANGUAGE_EVENTS] = language_events_column_feature()
if not extra:
return features
return datasets.Features({**features, **extra})
def _check_cached_episodes_sufficient(self) -> bool:
"""Check if the cached dataset contains all requested episodes and their video files."""
if self.hf_dataset is None or len(self.hf_dataset) == 0:
-6
View File
@@ -261,8 +261,6 @@ def merge_datasets(
datasets: list[LeRobotDataset],
output_repo_id: str,
output_dir: str | Path | None = None,
concatenate_videos: bool = True,
concatenate_data: bool = True,
) -> LeRobotDataset:
"""Merge multiple LeRobotDatasets into a single dataset.
@@ -272,8 +270,6 @@ def merge_datasets(
datasets: List of LeRobotDatasets to merge.
output_repo_id: Merged dataset identifier.
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
concatenate_videos: When False, keep one mp4 per source file instead of packing into shards.
concatenate_data: When False, keep one parquet per source file instead of packing into shards.
"""
if not datasets:
raise ValueError("No datasets to merge")
@@ -288,8 +284,6 @@ def merge_datasets(
aggr_repo_id=output_repo_id,
roots=roots,
aggr_root=output_dir,
concatenate_videos=concatenate_videos,
concatenate_data=concatenate_data,
)
merged_dataset = LeRobotDataset(
-79
View File
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from pprint import pformat
import torch
@@ -131,81 +130,3 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
return dataset
def make_train_eval_datasets(
cfg: TrainPipelineConfig,
) -> tuple[LeRobotDataset | MultiLeRobotDataset, LeRobotDataset | None]:
"""Create train and optional eval datasets by splitting episodes based on eval_split.
The last ceil(n_episodes * eval_split) episodes per task are held out for evaluation.
If eval_split == 0.0, returns (full_dataset, None).
"""
full_dataset = make_dataset(cfg)
if cfg.dataset.eval_split == 0.0:
return full_dataset, None
base_episodes = (
full_dataset.episodes if full_dataset.episodes is not None else list(range(full_dataset.num_episodes))
)
episode_tasks = full_dataset.meta.episodes["tasks"]
task_to_episodes: dict[str, list[int]] = {}
for ep_idx in base_episodes:
task_key = episode_tasks[ep_idx][0] if episode_tasks[ep_idx] else ""
task_to_episodes.setdefault(task_key, []).append(ep_idx)
train_episodes, eval_episodes = [], []
for eps in task_to_episodes.values():
n_eval = math.ceil(len(eps) * cfg.dataset.eval_split)
train_episodes.extend(eps[: len(eps) - n_eval])
eval_episodes.extend(eps[len(eps) - n_eval :])
if not train_episodes:
raise ValueError(
f"eval_split={cfg.dataset.eval_split} leaves 0 training episodes from {len(base_episodes)} total."
)
logging.info(
f"Train/eval split: {len(train_episodes)} train, {len(eval_episodes)} eval "
f"(eval_split={cfg.dataset.eval_split}, {len(task_to_episodes)} tasks)"
)
delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, full_dataset.meta)
train_image_transforms = (
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
)
train_dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=train_episodes,
delta_timestamps=delta_timestamps,
image_transforms=train_image_transforms,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
return_uint8=True,
tolerance_s=cfg.tolerance_s,
)
eval_dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=eval_episodes,
delta_timestamps=delta_timestamps,
image_transforms=None,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
return_uint8=True,
tolerance_s=cfg.tolerance_s,
)
if cfg.dataset.use_imagenet_stats:
for ds in (train_dataset, eval_dataset):
for key in ds.meta.camera_keys:
for stats_type, stats in IMAGENET_STATS.items():
ds.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
return train_dataset, eval_dataset
+89 -3
View File
@@ -170,6 +170,29 @@ def render_sample(
"""
persistent_rows = _normalize_rows(persistent or [])
event_rows = _normalize_rows(events or [])
# VQA-priority routing. A ``vqa`` annotation is sparse and
# view-dependent; the plain weighted blend would (a) waste a draw
# whenever it picks an ``ask_vqa*`` sub-recipe for a frame that has
# no VQA, and (b) silently drop a VQA-annotated frame whenever it
# picks a non-VQA sub-recipe. So: if the blend has ``ask_vqa*``
# sub-recipes and *this* frame carries one of their VQA bindings,
# render VQA here regardless of the weighted draw. That makes VQA's
# recipe-side training share equal the VQA-annotation density (the
# maximum reachable without a dataset-level oversampling sampler).
if recipe.blend is not None:
vqa_rendered = _render_vqa_if_present(
recipe,
persistent=persistent_rows,
events=event_rows,
t=t,
sample_idx=sample_idx,
task=task,
dataset_ctx=dataset_ctx,
)
if vqa_rendered is not None:
return vqa_rendered
selected_recipe = _select_recipe(recipe, sample_idx)
bindings = _resolve_bindings(
selected_recipe,
@@ -183,6 +206,59 @@ def render_sample(
return _render_message_recipe(selected_recipe, bindings)
def _render_vqa_if_present(
recipe: TrainingRecipe,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
sample_idx: int,
task: str | None,
dataset_ctx: Any | None,
) -> RenderedMessages | None:
"""Render an ``ask_vqa*`` sub-recipe iff this frame carries a VQA
annotation; otherwise return ``None`` so the caller falls back to the
normal weighted blend.
When several VQA sub-recipes resolve (e.g. a frame annotated for more
than one camera), one is chosen deterministically by relative weight.
"""
assert recipe.blend is not None
renderable: list[tuple[float, RenderedMessages]] = []
for name, component in recipe.blend.items():
if not name.startswith("ask_vqa"):
continue
bindings = _resolve_bindings(
component,
persistent=persistent,
events=events,
t=t,
sample_idx=sample_idx,
task=task,
dataset_ctx=dataset_ctx,
)
rendered = _render_message_recipe(component, bindings)
if rendered is not None:
renderable.append((float(component.weight or 0.0), rendered))
if not renderable:
return None
if len(renderable) == 1:
return renderable[0][1]
# Multiple cameras have a VQA for this frame — deterministic pick by
# relative weight (fall back to a uniform draw if all weights are 0).
total = sum(w for w, _ in renderable) or float(len(renderable))
digest = hashlib.blake2b(f"vqa:{sample_idx}".encode(), digest_size=8).digest()
draw = int.from_bytes(digest, "big") / 2**64 * total
cumulative = 0.0
for w, rendered in renderable:
cumulative += w or (total / len(renderable))
if draw < cumulative:
return rendered
return renderable[-1][1]
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
if recipe.blend is None:
@@ -346,7 +422,15 @@ def _render_message_recipe(
if turn.target:
target_indices.append(message_idx)
if not target_indices:
# A render is meaningful if it supervises *something*: either a
# text-CE target turn, or a ``low_level`` stream turn (flow / action
# supervision — e.g. the flow-only ``low_level_execution`` recipe,
# ``user(${subtask})`` with ``stream: low_level`` and no target).
# Without this, a flow-only recipe renders to ``None`` every time
# the blend draws it → ``predict_actions`` is never True → the
# action expert never receives a flow loss.
has_low_level = any(stream == "low_level" for stream in streams)
if not target_indices and not has_low_level:
return None
rendered = {
@@ -403,8 +487,10 @@ def _validate_rendered(rendered: RenderedMessages) -> None:
if len(streams) != len(messages):
raise ValueError("message_streams must be aligned with messages.")
if not target_indices:
raise ValueError("Rendered samples must contain at least one target message.")
# Valid iff it supervises something: a text-CE target turn OR a
# ``low_level`` stream turn (flow / action supervision).
if not target_indices and not any(s == "low_level" for s in streams):
raise ValueError("Rendered samples must contain a target message or a low_level-stream message.")
for idx in target_indices:
if idx < 0 or idx >= len(messages):
raise ValueError(f"Target message index {idx} is out of bounds.")
-2
View File
@@ -474,8 +474,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
if reader.hf_dataset is None:
# One-shot load after finalize()
reader.load_and_activate()
if reader._absolute_to_relative_idx is not None and idx in reader._absolute_to_relative_idx:
idx = reader._absolute_to_relative_idx[idx]
return reader.get_item(idx)
def select_columns(self, column_names: str | list[str]):
+91 -118
View File
@@ -14,36 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from collections.abc import Iterator
import numpy as np
import torch
logger = logging.getLogger(__name__)
class EpisodeAwareSampler:
"""Sampler over episode frames that stores only per-episode boundaries.
Logical positions map to frame indices on the fly (O(num_episodes) construction memory)
instead of materializing a Python list of every frame index.
Each epoch is shuffled with a `torch.randperm` seeded from `(seed, epoch)`, so the data order
is a pure function of `(seed, epoch)`: it reproduces on every rank without synchronizing the
global RNG (no `generator` to sync across distributed ranks), and `state_dict` /
`load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and
continuing from the saved offset. Each call to `__iter__` advances the epoch. During a
resumed epoch, `__len__` still reports the full length.
Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict`
set it explicitly. Within a single run callers should rely on exactly one of these mechanisms,
not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same
iterations would skip or repeat epochs. The training loop drives it purely through `__iter__`
(via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration
starts (e.g. on resume or in tests).
"""
def __init__(
self,
dataset_from_indices: list[int],
@@ -52,125 +30,120 @@ class EpisodeAwareSampler:
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
seed: int = 0,
):
"""
"""Sampler that optionally incorporates episode boundary information.
Args:
dataset_from_indices: Start index of each episode in the dataset.
dataset_to_indices: End index of each episode in the dataset.
episode_indices_to_use: Episode indices to use; None means all.
drop_n_first_frames: Frames to drop from the start of each episode.
drop_n_last_frames: Frames to drop from the end of each episode.
dataset_from_indices: List of indices containing the start of each episode in the dataset.
dataset_to_indices: List of indices containing the end of each episode in the dataset.
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
seed: Seed the permutation is derived from (together with the epoch).
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
if drop_n_last_frames < 0:
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
if from_indices.shape != to_indices.shape:
raise ValueError(
f"dataset_from_indices and dataset_to_indices must have the same length, "
f"got {len(from_indices)} and {len(to_indices)}"
)
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(dataset_from_indices, dataset_to_indices, strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
ep_length = end_index - start_index
if drop_n_first_frames + drop_n_last_frames >= ep_length:
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
ep_length,
drop_n_first_frames,
drop_n_last_frames,
)
continue
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
used = np.ones(len(from_indices), dtype=bool)
if episode_indices_to_use is not None:
used = np.zeros(len(from_indices), dtype=bool)
used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True
starts = from_indices + drop_n_first_frames
lengths = to_indices - drop_n_last_frames - starts
for episode_idx in np.flatnonzero(used & (lengths <= 0)):
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
to_indices[episode_idx] - from_indices[episode_idx],
drop_n_first_frames,
drop_n_last_frames,
)
used &= lengths > 0
if not used.any():
if not indices:
raise ValueError(
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
"All episodes were either filtered out or had too few frames."
)
self._starts = starts[used]
self._cum_lengths = np.cumsum(lengths[used])
self._num_frames = int(self._cum_lengths[-1])
self.indices = indices
self.shuffle = shuffle
self.seed = seed
self._epoch = 0
self._start_index = 0
@property
def indices(self) -> list[int]:
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
return [self._frame_index(k) for k in range(self._num_frames)]
def set_epoch(self, epoch: int) -> None:
self._epoch = epoch
def state_dict(self) -> dict:
return {"epoch": self._epoch, "start_index": self._start_index}
def load_state_dict(self, state: dict) -> None:
self._epoch = state["epoch"]
self._start_index = state["start_index"]
def _epoch_generator(self, epoch: int) -> torch.Generator:
# Derive a per-epoch seed from (seed, epoch) so the permutation is a pure function of both
# and reproduces identically on every rank without touching the global RNG.
epoch_seed = int(np.random.SeedSequence([self.seed, epoch]).generate_state(1, dtype=np.uint64)[0])
return torch.Generator().manual_seed(epoch_seed)
def _frame_index(self, position: int) -> int:
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
return int(self._starts[episode]) + position_in_episode
def __iter__(self) -> Iterator[int]:
# Advance epoch state eagerly, not on first consumption of the generator.
epoch, start = self._epoch, self._start_index
self._epoch += 1
self._start_index = 0
return self._iter_epoch(epoch, start)
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
if self.shuffle:
order = torch.randperm(self._num_frames, generator=self._epoch_generator(epoch))
for k in range(start, self._num_frames):
yield self._frame_index(int(order[k]))
for i in torch.randperm(len(self.indices)):
yield self.indices[i]
else:
for k in range(start, self._num_frames):
yield self._frame_index(k)
for i in self.indices:
yield i
def __len__(self) -> int:
return self._num_frames
return len(self.indices)
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
"""Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume.
class WeightedEpisodeAwareSampler(EpisodeAwareSampler):
"""``EpisodeAwareSampler`` that draws frames *with replacement* in
proportion to per-frame weights.
Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
per epoch (`even_batches` padding included). The start index provably stays below
`num_frames`; the `min` is defensive.
Assumptions (resume is only sample-exact when they hold):
- `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how
many positions a step consumes, so the epoch/offset are wrong if either changed. The
caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch.
- accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term
mirrors that padding; with `even_batches=False` the per-epoch batch count differs and
the boundary is off.
Used to oversample frames carrying a sparse annotation (e.g. a VQA
question) so the policy sees them more often than their natural
dataset density. One epoch still yields ``len(self.indices)``
samples — the weights only change the *composition* of the stream,
not its length. Each epoch re-draws, so the oversampled subset
varies run to run.
"""
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
start_index = min(batches_into_epoch * batch_size * num_processes, num_frames)
return {"epoch": epoch, "start_index": start_index}
def __init__(
self,
dataset_from_indices: list[int],
dataset_to_indices: list[int],
frame_weights,
*,
episode_indices_to_use: list | None = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
):
"""
Args:
dataset_from_indices: Episode start indices (see ``EpisodeAwareSampler``).
dataset_to_indices: Episode end indices.
frame_weights: 1-D sequence/tensor of non-negative weights, one per
dataset frame (length == total dataset frames). Higher weight ⇒
that frame is sampled more often.
episode_indices_to_use / drop_n_first_frames / drop_n_last_frames:
Same meaning as ``EpisodeAwareSampler`` — the episode-boundary
frame filtering is applied first, then weighting is restricted
to the surviving frames.
"""
super().__init__(
dataset_from_indices,
dataset_to_indices,
episode_indices_to_use=episode_indices_to_use,
drop_n_first_frames=drop_n_first_frames,
drop_n_last_frames=drop_n_last_frames,
shuffle=False,
)
weights = torch.as_tensor(frame_weights, dtype=torch.double).flatten()
idx = torch.tensor(self.indices, dtype=torch.long)
if weights.numel() <= int(idx.max()):
raise ValueError(
f"frame_weights has {weights.numel()} entries but the sampler "
f"references frame index {int(idx.max())}."
)
selected = weights[idx]
if not torch.isfinite(selected).all() or bool((selected < 0).any()):
raise ValueError("frame_weights must be finite and non-negative.")
if float(selected.sum()) <= 0.0:
# All surviving frames have zero weight — fall back to uniform.
selected = torch.ones_like(selected)
self._weights = selected
def __iter__(self) -> Iterator[int]:
picks = torch.multinomial(self._weights, num_samples=len(self.indices), replacement=True)
for i in picks.tolist():
yield self.indices[i]
+17 -10
View File
@@ -366,17 +366,24 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
hub_versions = get_repo_versions(repo_id)
if not hub_versions:
raise RevisionNotFoundError(
f"""Your dataset must be tagged with a codebase version.
Assuming _version_ is the codebase_version value in the info.json, you can run this:
```python
from huggingface_hub import HfApi
hub_api = HfApi()
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
```
"""
msg = (
f"Repo {repo_id!r} has no codebase-version tags. The dataset "
f"either doesn't exist on the Hub yet, or it was uploaded "
f"without a ``v3.x``-style tag. To tag an existing dataset run:\n"
f" from huggingface_hub import HfApi\n"
f" HfApi().create_tag({repo_id!r}, tag='v3.0', repo_type='dataset', exist_ok=True)"
)
# ``RevisionNotFoundError`` extends ``HfHubHTTPError`` whose
# ``__init__`` indexes ``response.headers`` unconditionally on
# current ``huggingface_hub`` versions. Constructing it without
# a real ``Response`` object crashes with either
# ``TypeError: missing 1 required keyword-only argument`` (old
# builds) or ``AttributeError: 'NoneType' object has no attribute
# 'headers'`` (new builds). Skip that path entirely — this isn't
# really an HTTP error, it's a configuration issue — and raise a
# plain ``RuntimeError`` so the message actually reaches the
# caller.
raise RuntimeError(msg)
if target_version in hub_versions:
return f"v{target_version}"
+1 -21
View File
@@ -481,10 +481,8 @@ def reencode_video(
encoder_threads: int | None = None,
log_level: int | None = av.logging.WARNING,
overwrite: bool = False,
start_time_s: float | None = None,
end_time_s: float | None = None,
) -> None:
"""Re-encode a video file, optionally trimming it to ``[start_time_s, end_time_s)``.
"""Re-encode a video file using the given encoder configuration.
Args:
input_video_path: Existing video file to read.
@@ -493,17 +491,10 @@ def reencode_video(
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
start_time_s: When set, trim the output to start at this timestamp (seconds).
end_time_s: When set, trim the output to end at this timestamp (seconds, exclusive).
"""
camera_encoder = camera_encoder or camera_encoder_defaults()
if (start_time_s is not None and start_time_s < 0) or (end_time_s is not None and end_time_s < 0):
raise ValueError(f"Trim times must be non-negative, got start={start_time_s}, end={end_time_s}.")
if start_time_s is not None and end_time_s is not None and end_time_s <= start_time_s:
raise ValueError(f"end_time_s ({end_time_s}) must be greater than start_time_s ({start_time_s}).")
output_video_path = Path(output_video_path)
if output_video_path.exists() and not overwrite:
@@ -535,10 +526,6 @@ def reencode_video(
width = int(in_stream.width)
height = int(in_stream.height)
# Seek to the keyframe at or before start_time_s to avoid reading from the start.
if start_time_s is not None:
src.seek(int(start_time_s * av.time_base), backward=True)
with av.open(
tmp_output_video_path,
mode="w",
@@ -552,14 +539,7 @@ def reencode_video(
out_stream.height = height
for frame in src.decode(in_stream):
frame_time_s = frame.time
if start_time_s is not None and frame_time_s < start_time_s:
continue
if end_time_s is not None and frame_time_s >= end_time_s:
break
frame = frame.reformat(width=width, height=height, format=pix_fmt)
if start_time_s is not None:
frame.pts = None # reset timestamps so the trimmed output starts at t=0
packet = out_stream.encode(frame)
if packet:
dst.mux(packet)
+13 -10
View File
@@ -33,8 +33,8 @@ logger = logging.getLogger(__name__)
# Dimensions for the flat action/state vectors used by the LeRobot wrapper.
# These correspond to the PandaOmron robot in RoboCasa365.
OBS_STATE_DIM = 16 # base_pos(3) + base_quat(4) + ee_pos_rel(3) + ee_quat_rel(4) + gripper_qpos(2)
ACTION_DIM = 12 # base_motion(4) + control_mode(1) + ee_pos(3) + ee_rot(3) + gripper(1)
OBS_STATE_DIM = 16 # ee_pos_rel(3) + ee_quat_rel(4) + base_pos(3) + base_quat(4) + gripper_qpos(2)
ACTION_DIM = 12 # ee_pos(3) + ee_rot(3) + gripper(1) + base_motion(4) + control_mode(1)
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
@@ -101,14 +101,15 @@ def _resolve_tasks(task: str) -> tuple[list[str], str | None]:
def convert_action(flat_action: np.ndarray) -> dict[str, Any]:
"""Split a flat (12,) action vector into a RoboCasa action dict.
Layout: base_motion(4) + control_mode(1) + ee_pos(3) + ee_rot(3) + gripper(1)
Layout (openpi / robocasa.utils.env_utils.convert_action order):
ee_pos(3) + ee_rot(3) + gripper(1) + base_motion(4) + control_mode(1)
"""
return {
"action.base_motion": flat_action[0:4],
"action.control_mode": flat_action[4:5],
"action.end_effector_position": flat_action[5:8],
"action.end_effector_rotation": flat_action[8:11],
"action.gripper_close": flat_action[11:12],
"action.end_effector_position": flat_action[0:3],
"action.end_effector_rotation": flat_action[3:6],
"action.gripper_close": flat_action[6:7],
"action.base_motion": flat_action[7:11],
"action.control_mode": flat_action[11:12],
}
@@ -230,12 +231,14 @@ class RoboCasaEnv(gym.Env):
return {"pixels": images}
# `state.*` keys come from PandaOmronKeyConverter inside the wrapper.
# openpi state order: ee first, then base, then gripper (matches the
# openpi robocasa pipeline / examples/robocasa/main.py state layout).
agent_pos = np.concatenate(
[
raw_obs.get("state.base_position", np.zeros(3)),
raw_obs.get("state.base_rotation", np.zeros(4)),
raw_obs.get("state.end_effector_position_relative", np.zeros(3)),
raw_obs.get("state.end_effector_rotation_relative", np.zeros(4)),
raw_obs.get("state.base_position", np.zeros(3)),
raw_obs.get("state.base_rotation", np.zeros(4)),
raw_obs.get("state.gripper_qpos", np.zeros(2)),
],
axis=-1,
+2
View File
@@ -104,6 +104,8 @@ class AdamWConfig(OptimizerConfig):
eps: float = 1e-8
weight_decay: float = 1e-2
grad_clip_norm: float = 10.0
foreach: bool | None = None
fused: bool | None = None
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
kwargs = asdict(self)
+2
View File
@@ -25,6 +25,7 @@ from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as M
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .pi052.configuration_pi052 import PI052Config as PI052Config
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
@@ -49,6 +50,7 @@ __all__ = [
"PI0Config",
"PI0FastConfig",
"PI05Config",
"PI052Config",
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
+128 -2
View File
@@ -63,6 +63,79 @@ from .wall_x.configuration_wall_x import WallXConfig
from .xvla.configuration_xvla import XVLAConfig
def _restore_pi052_pretrained_state(
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
pretrained_path: str,
) -> None:
"""Transplant saved stateful blobs from a pi052 checkpoint into fresh pipelines.
pi052's preprocessor includes steps whose constructor args don't
JSON-roundtrip (``RenderMessagesStep.recipe`` is a Python object,
``ActionTokenizerProcessorStep.action_tokenizer_name`` is a
fitted-tokenizer path that may not exist at eval time). We rebuild
those pipelines fresh from ``config.recipe_path`` and then walk
over the saved ``policy_{pre,post}processor.json`` files to find
each step's ``state_file`` reference and load the bytes back into
the corresponding fresh step. Today that's only the
NormalizerProcessorStep / UnnormalizerProcessorStep (the action /
state quantile stats), but the loop is generic so any future
stateful step picks up its blob automatically.
Pairing is by ``registry_name`` AND position so a benign reorder
on the saved side surfaces a warning rather than silently feeding
the wrong tensors into the wrong step.
"""
import json # noqa: PLC0415
import logging # noqa: PLC0415
from pathlib import Path # noqa: PLC0415
from safetensors.torch import load_file # noqa: PLC0415
base = Path(pretrained_path)
if not base.exists():
return
log = logging.getLogger(__name__)
for pipeline, config_filename in [
(preprocessor, f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"),
(postprocessor, f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"),
]:
config_path = base / config_filename
if not config_path.exists():
continue
saved = json.loads(config_path.read_text())
for idx, (saved_step, fresh_step) in enumerate(
zip(saved.get("steps", []), pipeline.steps, strict=False)
):
state_file = saved_step.get("state_file")
if not state_file:
continue
saved_name = saved_step.get("registry_name")
fresh_name = getattr(type(fresh_step), "_registry_name", None)
if saved_name and fresh_name and saved_name != fresh_name:
log.warning(
"PI052 state restore: %s step %d registry name mismatch "
"(saved=%s, fresh=%s); skipping %s",
config_filename, idx, saved_name, fresh_name, state_file,
)
continue
state_path = base / state_file
if not state_path.exists():
log.warning(
"PI052 state restore: %s missing at %s; %s left at fresh init",
state_file, base, fresh_name,
)
continue
fresh_step.load_state_dict(load_file(str(state_path)))
log.info(
"PI052 state restore: loaded %s into %s (step %d)",
state_file, fresh_name, idx,
)
def _reconnect_relative_absolute_steps(
preprocessor: PolicyProcessorPipeline, postprocessor: PolicyProcessorPipeline
) -> None:
@@ -130,6 +203,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .pi05.modeling_pi05 import PI05Policy
return PI05Policy
elif name == "pi052":
from .pi052.modeling_pi052 import PI052Policy
return PI052Policy
elif name == "gaussian_actor":
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
@@ -178,8 +255,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
"smolvla", "wall_x", "molmoact2".
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05",
"pi052", "gaussian_actor", "smolvla", "wall_x", "molmoact2".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -202,6 +279,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi05":
return PI05Config(**kwargs)
elif policy_type == "pi052":
from .pi052.configuration_pi052 import PI052Config
return PI052Config(**kwargs)
elif policy_type == "gaussian_actor":
return GaussianActorConfig(**kwargs)
elif policy_type == "smolvla":
@@ -246,6 +327,12 @@ class ProcessorConfigKwargs(TypedDict, total=False):
preprocessor_overrides: dict[str, Any] | None
postprocessor_overrides: dict[str, Any] | None
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
# Optional: HF Hub repo id of the dataset the policy is being
# trained on. Used by policies that auto-fit pieces of their
# preprocessing (e.g. pi052's FAST action tokenizer per
# Pertsch et al. 2025 [64], π0.5 §III.C). When omitted, those
# policies fall back to their universal pre-fitted tokenizers.
dataset_repo_id: str | None
dataset_meta: Any | None
@@ -279,6 +366,29 @@ def make_pre_post_processors(
NotImplementedError: If a processor factory is not implemented for the given
policy configuration type.
"""
if pretrained_path and getattr(policy_cfg, "type", None) == "pi052":
# pi052 pipelines don't roundtrip through the saved
# ``policy_preprocessor.json``: ``RenderMessagesStep`` holds a
# Python ``TrainingRecipe`` (not JSON-serializable; saved as
# ``{}``) and ``ActionTokenizerProcessorStep`` saves a host-only
# FAST tokenizer path. Generic ``from_pretrained`` then dies
# with ``RenderMessagesStep.__init__() missing 1 required
# positional argument: 'recipe'`` (job 22164494).
#
# Mirror ``lerobot_pi052_runtime``'s bootstrap: build pipelines
# fresh from ``config.recipe_path`` and transplant the saved
# stateful blobs (normalizer stats) from the checkpoint dir.
from .pi052.processor_pi052 import make_pi052_pre_post_processors
preprocessor, postprocessor = make_pi052_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_repo_id=kwargs.get("dataset_repo_id"),
)
_restore_pi052_pretrained_state(preprocessor, postprocessor, pretrained_path)
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
return preprocessor, postprocessor
if pretrained_path:
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
if isinstance(policy_cfg, GrootConfig):
@@ -373,6 +483,22 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif policy_cfg.type == "pi052":
# NOTE: PI052Config subclasses PI05Config, so this branch MUST
# come before the PI05Config isinstance check below (otherwise
# pi052 would silently pick up π0.5's processor).
from .pi052.processor_pi052 import make_pi052_pre_post_processors
processors = make_pi052_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
# ``dataset_repo_id`` flows in via kwargs when FAST CE is
# enabled — the train loop sets it from ``--dataset.repo_id``.
# When ``None``, ``make_pi052_pre_post_processors`` skips
# the auto-fit and uses the universal tokenizer.
dataset_repo_id=kwargs.get("dataset_repo_id"),
)
elif isinstance(policy_cfg, PI05Config):
from .pi05.processor_pi05 import make_pi05_pre_post_processors
-1
View File
@@ -178,7 +178,6 @@ N_COLOR_CHANNELS = 3
# config
@strict
class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5"
+42
View File
@@ -0,0 +1,42 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 — full reproduction of the π0.5 paper's hierarchical
inference recipe on lerobot.
Extends :class:`lerobot.policies.pi05.PI05Policy` with:
* recipe-driven training (PR 1's :class:`RenderMessagesStep`),
* PaliGemma ``lm_head`` cross-entropy on supervised subtask spans
(the "high-level subtask prediction" of the paper, §IV.D),
* AR text generation at inference (:meth:`PI052Policy.select_message`),
* per-component prompt dropout (Pi 0.7 §V.E) for regularising the
text head against missing context at inference.
See ``src/lerobot/configs/recipes/subtasks_vqa.yaml`` for the
canonical training recipe and
``examples/training/pi052_hirobot.slurm`` for the launcher.
"""
from .configuration_pi052 import PI052Config
from .modeling_pi052 import PI052Policy
from .processor_pi052 import make_pi052_pre_post_processors
from .text_processor_pi052 import PI052TextTokenizerStep
__all__ = [
"PI052Config",
"PI052Policy",
"PI052TextTokenizerStep",
"make_pi052_pre_post_processors",
]
@@ -0,0 +1,235 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 (with text head) — reproduction of the π0.5 paper's
hierarchical inference recipe.
Same architecture as the existing ``PI05Policy`` (PaliGemma 2B VLM +
~300M Gemma action expert, joint training with FAST tokens during
pre-train and flow matching during post-train), but with the
PaliGemma ``lm_head`` re-enabled so the same model can be supervised
to predict both:
* **subtask strings** at the high level (cross-entropy on the LM
head), and
* **action chunks** at the low level (flow matching on the
action-expert tokens).
This is the dual-head co-training pattern from the paper:
L = H(x, f_θ_text) + α * ‖ω - a - f_θ_action(a_τ, o, )‖²
with α = 10.0 per § IV.D of arxiv:2504.16054. The π0.5 model splits
inference into a text-prediction step followed by an action-prediction
step, which the multi-rate ``PI052Runtime`` (in
``lerobot.policies.pi052.inference``) drives at separate rates.
"""
from dataclasses import dataclass
from lerobot.configs import PreTrainedConfig
from lerobot.optim.optimizers import AdamWConfig
from ..pi05.configuration_pi05 import PI05Config
@PreTrainedConfig.register_subclass("pi052")
@dataclass
class PI052Config(PI05Config):
"""π0.5 with the PaliGemma LM head re-enabled for subtask prediction.
Recipe-driven dual-head training: the flow head supervises actions,
the LM head supervises subtask / plan / memory / VQA text. The
flow:text loss split is the milder 5:1 (see ``flow_loss_weight``).
"""
# Recipe / language stack ---------------------------------------------
recipe_path: str | None = "recipes/subtasks_vqa.yaml"
"""Path (absolute or relative to ``src/lerobot/configs/``) to a
``TrainingRecipe`` YAML. Defaults to the canonical Hi-Robot blend
shipped alongside this policy. Set to ``None`` to disable recipe
rendering and fall back to π0.5's single-task ``Task: ... Action:``
prompt path (unannotated datasets keep working that way)."""
apply_chat_template: bool = False
"""PaliGemma is *not* chat-pretrained — its tokenizer doesn't ship a
chat template, so we don't apply one. The recipe renderer's output
is concatenated as a plain prefix + assistant suffix instead,
mirroring how the π0.5 paper's high-level inference samples text
auto-regressively after the prefix."""
# Loss weights --------------------------------------------------------
# Paper §IV.D uses α=10 between the flow and text terms, assuming
# text is a rare auxiliary task. With the recipe stack the flow-only
# `low_level` branch fires on a large share of samples, so α=10
# swamps the LM head and collapses generation into degenerate
# repetition. We use the milder 5:1 split here.
text_loss_weight: float = 1.0
"""Weight on the LM-head cross-entropy term. Set to ``0`` to disable
text training entirely (reverts to flow-only / π0.5 behaviour)."""
flow_loss_weight: float = 5.0
"""Weight on the action-expert flow-matching term. ``5.0`` — a milder
flow:text split than the paper's α=10, since the flow-only
``low_level`` recipe already gives the action expert frequent
gradient. Lower it further if the LM head still underfits."""
# Backbone training ---------------------------------------------------
unfreeze_lm_head: bool = True
"""Whether to keep the PaliGemma ``lm_head`` unfrozen for fine-tuning.
The existing ``PI05Policy`` zeroes / freezes the head on load
because it never reads from it. Must be ``True`` for π0.5-style
hierarchical inference."""
# Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
# Randomly drop non-target context messages so the LM head learns
# to handle missing /
# stale plan / memory at inference. Defaults to 0.0 so behaviour
# is identical until explicitly enabled.
plan_dropout_prob: float = 0.0
memory_dropout_prob: float = 0.0
subtask_dropout_prob: float = 0.0
# FAST discrete-action supervision — paper §III.B-C ------------------
# When enabled, actions are *also* tokenised via the FAST tokenizer
# ("physical-intelligence/fast") and supervised with cross-entropy
# on the PaliGemma LM head — exactly as in the paper's pre-training
# objective (Eq. 1 mixes FAST CE + flow MSE + subtask CE). The
# ActionTokenizerProcessorStep is wired into the preprocessor
# pipeline when this flag is set; the loss is computed in
# PI052Policy.forward.
enable_fast_action_loss: bool = True
"""If True, tokenise actions with the FAST tokenizer and add a
cross-entropy loss on the LM head. On by default to match the
π0.5 paper's three-loss objective (text CE + FAST CE + flow MSE,
§III.B-C Eq. 1). Set to False if you only want the
post-training-style flow + text recipe."""
action_tokenizer_name: str = "physical-intelligence/fast"
"""HF identifier for the FAST action tokenizer."""
max_action_tokens: int = 256
"""Maximum number of FAST tokens per action chunk."""
fast_skip_tokens: int = 128
"""Number of low-vocab tokens the FAST tokenizer skips to avoid
collisions with PaliGemma's text vocabulary."""
fast_action_loss_weight: float = 1.0
"""Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0."""
auto_fit_fast_tokenizer: bool = False
"""If True, the processor factory checks ``fast_tokenizer_cache_dir``
for a previously-fitted tokenizer keyed on ``(dataset_repo_id,
base_tokenizer_name, fit_samples)``. On cache miss, it loads
``action_tokenizer_name`` as a base, samples
``fast_tokenizer_fit_samples`` action chunks from the dataset, runs
``.fit()``, saves the result, and uses *that* fitted path as the
actual tokenizer. Pertsch et al. 2025 (FAST paper [64], π0.5 §III.C)
explicitly recommend per-dataset fitting for best compression.
Off by default because the fit requires a separate pre-training
pass over the dataset (~1-2 min on a medium dataset) and depends
on the FAST tokenizer snapshot having a ``.fit()`` method. Opt in
when you want paper-faithful compression; leave off to fall back
on the universal ``physical-intelligence/fast`` codebook."""
fast_tokenizer_cache_dir: str = "~/.cache/lerobot/fast_tokenizers"
"""Where fitted FAST tokenizers are stored. ``~`` expands."""
fast_tokenizer_fit_samples: int = 1024
"""Number of action chunks to sample for the fit. The FAST paper uses
a few thousand; 1024 is a reasonable default for medium datasets."""
# Knowledge insulation — paper §III.B --------------------------------
# When enabled, gradients from the action expert's flow loss are
# blocked from flowing back into the VLM's K/V projections. This
# prevents the action loss from over-fitting the language backbone
# to robot-specific features. Implemented in ``modeling_pi052`` as
# a per-instance monkey-patch on ``paligemma_with_expert.forward``
# that splits queries into VLM and action halves and ``.detach()``-s
# the VLM K/V tensors used in the action-half's attention.
knowledge_insulation: bool = False
"""If True, route every transformer layer through the KI
attention path that blocks action→VLM gradient flow on K/V."""
# Learning-rate defaults --------------------------------------------
# pi052 inherits π0.5's openpi-validated optimizer config (peak LR
# 2.5e-5, cosine→2.5e-6, 1k warmup, AdamW (0.9, 0.95), wd=0.01,
# grad_clip=1.0). The only place pi052 needs to diverge from pi05
# is the LM-head LR multiplier: pi05 has no text supervision so the
# head doesn't get gradients; pi052 always has text supervision
# (subtask / memory / VQA) via the recipe, and under KI the LM head
# only sees gradients on ~3045% of the batch (the text-CE mask
# share of the recipe). Under aggressive cosine decay this is too
# weak to keep the head pinned, so it drifts back toward PaliGemma's
# pretrained ``<loc>`` first-token bias. 5x is the documented fix
# (see ``PI05Config.lm_head_lr_scale`` docstring); the wiring is
# already in ``PI05Policy.get_optim_params`` — it splits the LM head
# + tied ``embed_tokens`` into their own param group while sharing
# the same cosine lambda, so the 5x ratio is preserved across decay.
lm_head_lr_scale: float = 5.0
# PaLM-style z-loss on text CE. Penalises the log-partition function
# ``z = log Σ exp(logits)`` drifting away from zero — without it, large-
# vocab models (PaliGemma is 257k) can let ``logsumexp`` grow unbounded
# while CE stays low, because a uniform additive logit bias cancels in
# softmax. PaLM appendix B / Chinchilla report z-loss is essential for
# stable large-vocab CE; it especially helps under ``lm_head_lr_scale=
# 5.0`` which amplifies drift risk on the LM head. ``1e-4`` is the
# commonly cited weight; set 0 to disable entirely.
text_ce_z_loss_weight: float = 1e-4
# Liger Triton kernels (rope + geglu + layer_norm) are now patched
# unconditionally at model build time — see ``_enable_hf_kernels``
# in ``modeling_pi052``. The patch is process-global, idempotent
# and degrades gracefully if ``liger-kernel`` is missing. Measured
# at -4.5% step time on H100 (bench job 22161421); peak memory
# unchanged. ``fused_linear_cross_entropy`` ships separately via
# ``_shifted_lin_ce`` / ``_fast_lin_ce``.
use_hf_kernels: bool = True
"""Deprecated. Liger HF kernels are patched unconditionally by
``_enable_hf_kernels`` — this field is retained as a no-op for
backward compatibility with checkpoints saved before commit
d70c8104 (which still serialize ``use_hf_kernels: true`` into
``config.json``). Loading those configs would otherwise raise
``DecodingError: The fields use_hf_kernels are not valid for
PI052Config`` (job 22164492). Remove in a future major bump."""
# Optimizer foreach/fused. pi052 carries these locally because the shared
# PI05Config (kept identical to upstream main) does not define them; the
# checkpoints we train serialize both keys into config.json, so they must
# be valid PI052Config fields and flow into the AdamW preset below.
optimizer_foreach: bool | None = False
optimizer_fused: bool | None = True
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
foreach=self.optimizer_foreach,
fused=self.optimizer_fused,
)
def __post_init__(self) -> None:
super().__post_init__()
# Backbone needs gradients flowing through the text head when
# we're training it. Override the π0.5 default
# (``train_expert_only=True``) unless the user explicitly opts
# out of text training via ``text_loss_weight=0``.
if self.text_loss_weight > 0 and self.unfreeze_lm_head:
self.train_expert_only = False
@@ -0,0 +1,304 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset-specific FAST action tokenizer fitting.
The published ``physical-intelligence/fast`` tokenizer is a *universal*
codebook fitted on a heterogeneous mix of robot datasets. Per Pertsch
et al. 2025 (the FAST paper, [64] in the π0.5 paper) and §III.C of
π0.5 itself, the recommended practice is to **finetune the tokenizer on
your specific dataset's action distribution** before training the
policy — same way one would adapt a language tokenizer to a domain
corpus. Without this finetune step, action sequences from your robot
may require more tokens per chunk than necessary, lowering effective
compression and slowing convergence of the action-CE loss.
This module provides a single utility, :func:`fit_fast_tokenizer`,
that does the finetune. The training entry point invokes it
automatically when the policy's ``enable_fast_action_loss`` and
``auto_fit_fast_tokenizer`` flags are both ``True`` and no cached
fitted tokenizer is found at ``fast_tokenizer_cache_dir``.
The fitted tokenizer is saved to
``{cache_dir}/{dataset_hash}_{base_hash}/`` so successive training
runs over the same dataset re-use it.
"""
from __future__ import annotations
import hashlib
import logging
import os
import time
from pathlib import Path
import numpy as np
logger = logging.getLogger(__name__)
# Marker file the cache-hit check looks for. ``ProcessorMixin.save_pretrained``
# writes ``processor_config.json`` (NOT ``preprocessor_config.json`` —
# that's the image / feature-extractor convention). Centralised here so
# the cache-hit check and the rank-N readiness wait agree on the same
# sentinel.
_CACHE_SENTINEL = "processor_config.json"
def _dataset_signature(
dataset_repo_id: str,
base_tokenizer_name: str,
n_samples: int,
chunk_size: int,
) -> str:
"""Deterministic short hash for naming the cache directory.
Keys on (dataset, base tokenizer, sample count, chunk size) so any
of those changing re-runs the fit. ``chunk_size`` matters because
the tokenizer is fit on chunks of that length.
"""
h = hashlib.sha256()
h.update(dataset_repo_id.encode("utf-8"))
h.update(b"\0")
h.update(base_tokenizer_name.encode("utf-8"))
h.update(b"\0")
h.update(str(n_samples).encode("utf-8"))
h.update(b"\0")
h.update(str(chunk_size).encode("utf-8"))
return h.hexdigest()[:16]
def fit_fast_tokenizer(
*,
dataset_repo_id: str,
cache_dir: str | Path,
base_tokenizer_name: str = "physical-intelligence/fast",
n_samples: int = 1024,
chunk_size: int = 50,
seed: int = 42,
) -> str:
"""Fit a FAST tokenizer on a LeRobot dataset's action distribution.
Args:
dataset_repo_id: HF Hub repo id of the LeRobotDataset to fit on.
cache_dir: Directory under which to save (and look up) fitted
tokenizers. The actual save path is
``{cache_dir}/{signature}``.
base_tokenizer_name: HF identifier for the base FAST tokenizer
to finetune from. ``physical-intelligence/fast`` is the
universal one.
n_samples: Number of action chunks to sample for the fit. The
FAST paper uses a few thousand; ``1024`` is a good default
for medium datasets.
chunk_size: Length of each action chunk (matches
``policy.chunk_size``). The FAST tokenizer is fit on
sequences of this length.
seed: RNG seed for sample selection.
Returns:
The local path to the fitted tokenizer. Passed directly to
``--policy.action_tokenizer_name`` for the training run.
Raises:
ImportError: If the ``transformers`` library doesn't expose
``AutoProcessor`` or the FAST tokenizer doesn't have a
``.fit()`` method (then you're on an older FAST snapshot —
update to the current published model).
FileNotFoundError: If the dataset can't be loaded.
"""
cache_dir = Path(cache_dir)
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size)
out_dir = cache_dir / sig
if out_dir.exists() and (out_dir / _CACHE_SENTINEL).exists():
logger.info(
"FAST tokenizer cache hit: %s — re-using fitted tokenizer for "
"dataset=%s base=%s n_samples=%d",
out_dir, dataset_repo_id, base_tokenizer_name, n_samples,
)
return str(out_dir)
# DDP-safe fit: only the (local) main process actually fits + saves;
# other ranks poll the cache sentinel until the leader is done.
# Without this guard, all N ranks fit concurrently and race on
# ``save_pretrained`` + ``AutoProcessor.from_pretrained`` (the latter
# copies ``processing_action_tokenizer.py`` into ``HF_MODULES_CACHE``
# and compiles a ``.pyc`` — concurrent writers occasionally produce
# a stale / partial ``.pyc`` and the subsequent ``from .. import
# UniversalActionProcessor`` raises ``AttributeError``.
is_leader = (
int(os.environ.get("RANK", "0")) == 0
and int(os.environ.get("LOCAL_RANK", "0")) == 0
)
if not is_leader:
timeout_s = 1800.0 # 30 min — covers ~1024-sample fits on cold caches
start = time.monotonic()
while not (out_dir / _CACHE_SENTINEL).exists():
if time.monotonic() - start > timeout_s:
raise RuntimeError(
f"FAST tokenizer fit: non-leader rank timed out after "
f"{timeout_s:.0f}s waiting for {out_dir / _CACHE_SENTINEL}. "
"Leader rank likely crashed during the fit."
)
time.sleep(2.0)
logger.info("FAST tokenizer ready (leader populated cache): %s", out_dir)
return str(out_dir)
logger.info(
"FAST tokenizer cache miss — fitting on dataset=%s "
"base=%s n_samples=%d chunk_size=%d%s",
dataset_repo_id, base_tokenizer_name, n_samples, chunk_size, out_dir,
)
from transformers import AutoProcessor # noqa: PLC0415
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415
# Stream a single episode's worth of action chunks at a time so
# we don't blow memory on huge datasets. Random episode +
# random start offset gives a reasonable spread.
#
# Actions are read straight from the underlying HF dataset's
# ``action`` *column* — never via ``ds[i]``. ``ds[i]`` builds a full
# training item (delta-timestamp expansion + video decode + image
# transforms); a single bad video frame would then throw and, since
# the failure was swallowed at debug level, silently starve the fit
# of every chunk. The action column carries no video, so reading it
# directly is both faster and immune to decode errors.
rng = np.random.default_rng(seed)
actions_buf: list[np.ndarray] = []
# Resolve the dataset's data parquet shards directly, sidestepping
# ``LeRobotDataset(repo_id, episodes=[N])`` which on v3-format
# datasets routes through HF datasets'' split lookup and raises
# ``ValueError: Instruction "train" corresponds to no data!`` for
# every episode (job 22182985 looped through 13,293 skipped episodes
# for ~2.5 h before NCCL killed it). Reading the ``action`` column
# straight from the parquet shards is also faster: each per-episode
# ``LeRobotDataset`` instantiation re-parses every meta file.
from huggingface_hub import snapshot_download # noqa: PLC0415
import pyarrow as _pa # noqa: PLC0415
import pyarrow.parquet as _pq # noqa: PLC0415
snap = Path(snapshot_download(repo_id=dataset_repo_id, repo_type="dataset"))
data_files = sorted((snap / "data").glob("chunk-*/file-*.parquet"))
if not data_files:
raise RuntimeError(
f"FAST fit: no ``data/chunk-*/file-*.parquet`` shards found under {snap!s}."
)
# Read just the (episode_index, action) columns once across all
# shards. This is the same pattern used elsewhere in the codebase
# for whole-dataset audits and stays under ~2 GB even on 32 k-episode
# / 29 M-frame datasets because the action column is a fixed-length
# float vector.
tables = [_pq.read_table(f, columns=["episode_index", "action"]) for f in data_files]
table = _pa.concat_tables(tables)
eps = table["episode_index"].to_numpy()
acts_col = table["action"]
# ``action`` may be a fixed-shape ListArray or a 2-D NumericArray;
# ``to_numpy(zero_copy_only=False)`` produces an object array of
# 1-D NumPy actions either way, which we stack into (N, D).
try:
acts = np.stack(acts_col.to_numpy(zero_copy_only=False)).astype(np.float32)
except Exception: # noqa: BLE001
# Fallback path for nested-list types: flatten via to_pylist().
acts = np.asarray(acts_col.to_pylist(), dtype=np.float32)
if acts.ndim != 2:
raise RuntimeError(
f"FAST fit: expected ``action`` rows to be 1-D vectors; got shape {acts.shape}."
)
# Episode index → slice (start, stop) into ``acts`` along axis 0.
# ``eps`` is monotonically increasing within each parquet shard but
# we make no assumption across shards — sort once and group.
order = np.argsort(eps, kind="stable")
eps_sorted = eps[order]
boundaries = np.searchsorted(eps_sorted, np.arange(int(eps_sorted.max()) + 2))
ep_to_slice: dict[int, tuple[int, int]] = {
int(ep): (int(boundaries[ep]), int(boundaries[ep + 1]))
for ep in range(len(boundaries) - 1)
if boundaries[ep] < boundaries[ep + 1]
}
num_episodes = len(ep_to_slice)
# ``acts`` is in original (un-sorted-by-episode) row order; reorder
# so per-episode slices are contiguous.
acts = acts[order]
samples_per_episode = max(1, n_samples // max(num_episodes, 1))
collected = 0
eps_visited = 0
short_episodes = 0
ep_indices = list(ep_to_slice.keys())
for ep_idx in rng.permutation(ep_indices):
if collected >= n_samples:
break
start, stop = ep_to_slice[int(ep_idx)]
ep_actions = acts[start:stop]
if ep_actions.shape[0] < chunk_size:
short_episodes += 1
continue
starts = rng.integers(0, ep_actions.shape[0] - chunk_size + 1, size=samples_per_episode)
for s in starts:
actions_buf.append(ep_actions[int(s) : int(s) + chunk_size])
collected += 1
if collected >= n_samples:
break
eps_visited += 1
if not actions_buf:
raise RuntimeError(
f"FAST fit collected zero action chunks from {dataset_repo_id!r}: "
f"all {num_episodes} episodes were shorter than chunk_size="
f"{chunk_size} ({short_episodes} too short) or had an unreadable "
"``action`` column. Lower ``chunk_size`` to match your episode "
"lengths."
)
actions = np.stack(actions_buf, axis=0).astype(np.float32) # (N, H, D)
logger.info(
"FAST fit: collected %d chunks of shape %s from %d episodes",
actions.shape[0], actions.shape[1:], eps_visited,
)
# Quantile-normalise per dimension before fitting.
#
# The FAST tokenizer DCT-transforms actions, scales by ``scale`` and
# rounds to integer tokens; the integer *range* must fit the
# codebook (vocab_size, default 1024). Raw motor units (e.g. encoder
# ticks) blow that range up — hence "Vocab size 1024 is too small".
# More importantly, at training time ``ActionTokenizerProcessorStep``
# runs *after* the QUANTILES ``NormalizerProcessorStep``, so it
# encodes normalised actions. Fitting on raw actions would mismatch
# that space. We replicate QUANTILES normalisation here (per-dim
# [q01, q99] → [-1, 1], clipped) so the fit and the training-time
# encode see the same distribution.
flat = actions.reshape(-1, actions.shape[-1])
q01 = np.quantile(flat, 0.01, axis=0)
q99 = np.quantile(flat, 0.99, axis=0)
span = np.where((q99 - q01) > 1e-6, q99 - q01, 1.0)
actions = np.clip((actions - q01) / span * 2.0 - 1.0, -1.0, 1.0).astype(np.float32)
base = AutoProcessor.from_pretrained(base_tokenizer_name, trust_remote_code=True)
if not hasattr(base, "fit"):
raise ImportError(
f"Base FAST tokenizer {base_tokenizer_name!r} has no ``.fit()`` "
"method — your transformers / model snapshot is too old. Update "
"to the current ``physical-intelligence/fast`` revision."
)
fitted = base.fit(actions)
out_dir.mkdir(parents=True, exist_ok=True)
fitted.save_pretrained(str(out_dir))
logger.info("FAST fit: saved fitted tokenizer to %s", out_dir)
return str(out_dir)
@@ -0,0 +1,73 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PI052 inference / runtime orchestration.
Multi-rate runtime that mirrors the recipe-time training shape:
low_level_execution → LowLevelForward + DispatchAction (high Hz)
high_level_subtask → HighLevelSubtaskFwd (~1 Hz)
memory_update → MemoryUpdateFwd (event: subtask_change)
user_interjection_response → UserInterjectionFwd (event: stdin)
ask_vqa_* → AskVQAFwd (event: stdin question)
speech tool calls → DispatchToolCalls (event: tool_call_pending)
The CLI ``lerobot-pi052-runtime`` builds a ``PI052Runtime`` and calls
``run()``.
"""
from .repl import StdinReader
from .runtime import PI052Runtime
from .runtime_state import initial_runtime_state, push_log, set_if_changed, take_event
from .steps import (
AskVQAFwd,
DispatchAction,
DispatchToolCalls,
HighLevelSubtaskFwd,
InferenceStep,
LowLevelForward,
MemoryUpdateFwd,
UserInterjectionFwd,
)
from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger
from .ui import make_state_panel, print_robot_lines, print_user_line
__all__ = [
# runtime
"PI052Runtime",
"StdinReader",
# state helpers
"initial_runtime_state",
"push_log",
"set_if_changed",
"take_event",
# triggers
"Trigger",
"Tick",
"TickClock",
"HzTrigger",
"EventTrigger",
# steps
"InferenceStep",
"LowLevelForward",
"DispatchAction",
"HighLevelSubtaskFwd",
"MemoryUpdateFwd",
"UserInterjectionFwd",
"AskVQAFwd",
"DispatchToolCalls",
# UI
"make_state_panel",
"print_robot_lines",
"print_user_line",
]
@@ -0,0 +1,105 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stdin REPL event collector for the PI052 runtime.
Reads non-blocking stdin lines, classifies each one heuristically:
"stop" / "quit" / "exit" → state["stop"] = True
"/action" / "/pause" → set state["mode"]
ends with "?" → user_vqa_query event
starts with "task:" or first line → set runtime task
anything else → user_interjection event
Plugged into the runtime via ``event_collector=StdinReader().poll``.
Note: the shipped CLI (``lerobot-pi052-runtime``) drives stdin
directly in its REPL / autonomous loops and does *not* wire this
collector; it's kept as the documented embedding hook and for tests.
"""
from __future__ import annotations
import select
import sys
from dataclasses import dataclass, field
from typing import Any
@dataclass
class StdinReader:
"""Non-blocking stdin line collector for the runtime loop."""
prompt: str = "> "
_seen_first_line: bool = field(default=False, init=False)
_prompted: bool = field(default=False, init=False)
def poll(self, state: dict[str, Any]) -> None:
"""Drain pending stdin lines into runtime events."""
# Print the input prompt once on every fresh tick if we don't
# already have a pending line; matches the expected REPL feel.
if not self._prompted:
print(self.prompt, end="", flush=True)
self._prompted = True
# ``select`` with timeout=0 makes this non-blocking. Only works
# for actual TTY / pipe stdins; CI / scripted runs hit EOF.
try:
ready, _, _ = select.select([sys.stdin], [], [], 0)
except (ValueError, OSError):
return
if not ready:
return
line = sys.stdin.readline()
if not line: # EOF
state["stop"] = True
return
line = line.strip()
self._prompted = False # we'll re-prompt next tick
if not line:
return
lower = line.lower()
if lower in {"stop", "quit", "exit"}:
state["stop"] = True
return
# Slash commands flip the run mode. ``/pause`` stops the action
# loop (the action steps gate on ``state["mode"]``); ``/action``
# resumes it.
if lower.split(" ", 1)[0] in {"/action", "/act", "/run"}:
state["mode"] = "action"
return
if lower in {"/pause", "/p"}:
state["mode"] = "paused"
queue = state.get("action_queue")
if hasattr(queue, "clear"):
queue.clear()
return
# First non-control line sets the task if no task is active.
if not state.get("task"):
task = line[5:].strip() if lower.startswith("task:") else line
state["task"] = task
print(f"[pi052] Task: {task}", flush=True)
self._seen_first_line = True
return
# Question → VQA; statement → interjection.
if lower.endswith("?"):
state["recent_vqa_query"] = line
state.setdefault("events_this_tick", []).append("user_vqa_query")
else:
state["recent_interjection"] = line
state.setdefault("events_this_tick", []).append("user_interjection")
@@ -0,0 +1,205 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PI052 runtime loop.
Threads the multi-rate inference pipeline together with a stdin REPL
event collector, drives ticks through :class:`TickClock`, and prints
state-change updates to the user.
"""
from __future__ import annotations
import logging
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable
from .runtime_state import initial_runtime_state, push_log
from .steps import (
AskVQAFwd,
DispatchAction,
DispatchToolCalls,
HighLevelSubtaskFwd,
InferenceStep,
LowLevelForward,
MemoryUpdateFwd,
)
from .triggers import EventTrigger, HzTrigger, TickClock
logger = logging.getLogger(__name__)
@dataclass
class PI052Runtime:
"""Compose the inference pipeline and drive it tick-by-tick."""
policy: Any
tools: dict[str, Any] = field(default_factory=dict)
"""Name → tool-instance dict, e.g. ``{"say": SayTool(...)}``. Read
from :func:`lerobot.tools.get_tools(meta)` when wiring the
runtime."""
observation_provider: Callable[[], dict | None] | None = None
"""Closure returning the current preprocessed observation batch.
``None`` for dry-run / language-only sessions."""
robot_executor: Callable[[Any], None] | None = None
"""Closure that takes one action chunk and forwards it to the
robot. ``None`` for dry-run."""
event_collector: Callable[[dict], None] | None = None
"""Per-tick hook that polls external sources (stdin, network) and
appends event names to ``state["events_this_tick"]``."""
chunk_hz: float = 4.0
ctrl_hz: float = 50.0
high_level_hz: float = 1.0
max_rate_hz: float = 50.0
pipeline: list[InferenceStep] = field(init=False)
state: dict[str, Any] = field(init=False)
_stop: bool = field(default=False, init=False)
def __post_init__(self) -> None:
# Subtask + memory + VQA configuration. Pipeline:
#
# HighLevelSubtaskFwd → generate the next subtask via the LM
# head at ~``high_level_hz``; writes
# ``current_subtask`` and emits
# ``subtask_change`` on a transition.
# MemoryUpdateFwd → on ``subtask_change``, refresh
# ``current_memory`` from the
# ``memory_update`` head.
# AskVQAFwd → answer camera-grounded stdin questions.
# LowLevelForward → action chunk conditioned on the
# generated ``current_subtask``.
# DispatchAction → drain the chunk to the robot.
# DispatchToolCalls → fire any pending tool calls.
#
# Order matters: ``HighLevelSubtaskFwd`` must run before
# ``MemoryUpdateFwd`` so the event is visible the same tick, and
# both must run before ``LowLevelForward`` (which is gated on
# "action queue empty") so the chunk consumes the freshest
# subtask. ``UserInterjectionFwd`` is still importable but
# disabled until plan generation is wired in.
self.pipeline = [
HighLevelSubtaskFwd(
trigger=HzTrigger(self.high_level_hz),
policy=self.policy,
observation_provider=self.observation_provider,
),
# Listens for the ``subtask_change`` event raised by
# ``HighLevelSubtaskFwd`` and refreshes ``current_memory``.
MemoryUpdateFwd(
trigger=EventTrigger("subtask_change"),
policy=self.policy,
observation_provider=self.observation_provider,
),
AskVQAFwd(
policy=self.policy,
observation_provider=self.observation_provider,
),
LowLevelForward(
trigger=HzTrigger(self.chunk_hz),
policy=self.policy,
observation_provider=self.observation_provider,
),
DispatchAction(
trigger=HzTrigger(self.ctrl_hz),
robot_executor=self.robot_executor,
),
DispatchToolCalls(tools=self.tools),
]
self.state = initial_runtime_state()
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
def set_task(self, task: str) -> None:
"""Set or replace the active task. Logged for the REPL."""
self.state["task"] = task
push_log(self.state, f"Task: {task}")
def stop(self) -> None:
self._stop = True
def run(self, *, max_ticks: int | None = None) -> None:
"""Main loop. Returns when ``stop()`` is called or after
``max_ticks`` ticks (useful for tests / dry-run)."""
clock = TickClock(max_rate_hz=self.max_rate_hz)
while not self._stop:
tick = clock.advance()
self.state["_tick"] = tick
self.state["events_this_tick"] = []
self.state["log_lines"] = []
if self.event_collector is not None:
self.event_collector(self.state)
if self.state.get("stop"):
self._stop = True
break
for step in self.pipeline:
self.state = step(self.state)
self._flush_logs()
if max_ticks is not None and tick.index >= max_ticks:
break
self._on_shutdown()
# ------------------------------------------------------------------
# REPL helper: drive one full pipeline pass and return its logs
# ------------------------------------------------------------------
def step_once(self) -> list[str]:
"""Run one tick of the pipeline and return the log lines.
Used by the interactive REPL: instead of a background thread,
the CLI drives ticks synchronously after each user input. Logs
are returned (not printed) so the caller can route them into
the rich-Live chat scrollback.
"""
from .triggers import Tick # noqa: PLC0415
# Synthesize a tick. We don't need the real wall-clock pacing
# here — the REPL drives the runtime, not vice versa — but
# ``HzTrigger`` uses ``tick.monotonic_seconds`` to gate, so we
# bump it generously so every Hz-triggered step considers
# itself due.
import time as _time # noqa: PLC0415
prev_index = self.state.get("_tick").index if isinstance(self.state.get("_tick"), Tick) else 0
self.state["_tick"] = Tick(index=prev_index + 1, monotonic_seconds=_time.monotonic())
self.state["log_lines"] = []
# ``events_this_tick`` is set up by the caller before
# ``step_once`` (the REPL pushes user-driven events first).
self.state.setdefault("events_this_tick", [])
for step in self.pipeline:
self.state = step(self.state)
return list(self.state.get("log_lines") or [])
# ------------------------------------------------------------------
# I/O
# ------------------------------------------------------------------
def _flush_logs(self) -> None:
for line in self.state.get("log_lines") or []:
print(f"[pi052] {line}", flush=True)
def _on_shutdown(self) -> None:
# Drain any queued action chunks safely.
queue = self.state.get("action_queue")
if isinstance(queue, deque):
queue.clear()
print("[pi052] runtime stopped", flush=True)
@@ -0,0 +1,95 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runtime state passed between inference steps each tick.
The runtime threads a single dict through the pipeline; this module
documents the shape and provides factories. We use a plain ``dict``
rather than a frozen dataclass because steps freely add and remove
keys (``events_this_tick``, ``messages_pending``, ``tool_calls_pending``,
…) and dataclass field churn would just get in the way.
Stable keys (read by multiple steps):
task str the current top-level task
current_plan str | None latest plan emitted by the planner
current_subtask str | None latest subtask the policy is executing
current_memory str | None latest compressed memory
recent_interjection str | None most recent user interjection text (consumed)
action_queue collections.deque[Tensor] pending action chunks
tool_calls_pending list[dict] parsed but not-yet-dispatched tool calls
events_this_tick list[str] triggers consumed this tick
_tick Tick current tick (set by the loop)
mode str "action" (run the robot) | "paused"
(action loop stopped — robot holds)
log_lines list[str] human-readable status lines printed each tick
"""
from __future__ import annotations
from collections import deque
from typing import Any
def initial_runtime_state(task: str | None = None) -> dict[str, Any]:
"""Build a fresh runtime state dict with sensible defaults."""
return {
"task": task,
"current_plan": None,
"current_subtask": None,
"current_memory": None,
"recent_interjection": None,
"action_queue": deque(),
"tool_calls_pending": [],
"events_this_tick": [],
"log_lines": [],
"mode": "action",
"stop": False,
}
def take_event(state: dict[str, Any], event_name: str) -> bool:
"""Pop ``event_name`` from ``events_this_tick`` if present.
Steps that consume an event call this so the same event doesn't
re-fire on a sibling step within the same tick.
"""
events: list[str] = state.get("events_this_tick") or []
if event_name in events:
events.remove(event_name)
return True
return False
def push_log(state: dict[str, Any], line: str) -> None:
"""Append ``line`` to the per-tick log buffer; the runtime prints
it at the end of the tick."""
state.setdefault("log_lines", []).append(line)
def set_if_changed(state: dict[str, Any], key: str, value: Any, label: str | None = None) -> bool:
"""Update ``state[key]`` and log a diff line if the value changed.
Returns ``True`` if the value actually changed.
"""
prev = state.get(key)
if prev == value:
return False
state[key] = value
if label is not None:
push_log(state, f" {label}: {value}")
return True
@@ -0,0 +1,955 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference steps for the PI052 multi-rate runtime.
Each step is a tiny class with a ``trigger`` and an ``__call__(state)``;
the runtime applies them in order each tick. When a step's trigger
doesn't fire, the step is a no-op and the runtime moves on.
Stream-to-step mapping mirrors the ``subtasks_vqa.yaml`` recipe:
* ``LowLevelForward`` — calls ``policy.select_action`` for the
action chunk; trained by
``low_level_execution``
* ``EnqueueChunk`` — pushes the chunk to ``action_queue``
* ``DispatchAction`` — pops one action per control tick and
forwards to the robot
* ``HighLevelSubtaskFwd`` — calls ``policy.select_message`` for the
next subtask; trained by
``high_level_subtask``
* ``MemoryUpdateFwd`` — fires on subtask boundary; trained by
``memory_update``
* ``UserInterjectionFwd`` — fires on stdin interjection; trained by
``user_interjection_response``
* ``AskVQAFwd`` — fires on stdin question; trained by
``ask_vqa_*``
* ``DispatchToolCalls`` — pops ``tool_calls_pending`` and calls
the matching ``Tool`` instance
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from typing import Any
from .runtime_state import push_log, set_if_changed, take_event
from .triggers import EventTrigger, HzTrigger, Trigger
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Step base + runner
# ---------------------------------------------------------------------------
@dataclass
class InferenceStep:
"""A trigger-gated callable. Subclasses override :meth:`run`."""
trigger: Trigger
def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
if not self.trigger.should_fire(state["_tick"], state):
return state
return self.run(state) or state
def run(self, state: dict[str, Any]) -> dict[str, Any] | None: # pragma: no cover
raise NotImplementedError
# ---------------------------------------------------------------------------
# Low-level (action) path
# ---------------------------------------------------------------------------
@dataclass
class LowLevelForward(InferenceStep):
"""Run the policy's action head and produce one action chunk."""
policy: Any = None
observation_provider: Any = None
"""Callable ``() -> dict``: returns the current observation batch
(already preprocessed). Typically wraps the robot's camera /
proprio reads. ``None`` in dry-run mode → step skips."""
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=4.0))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or self.observation_provider is None:
return None
# ``/vlm`` mode pauses the whole action loop so the robot holds
# position while the operator probes the VLM with VQA.
if state.get("mode", "action") != "action":
return None
if not state.get("task"):
return None
# PI052 produces *action chunks* (typically 50 steps via
# flow-matching). Every step gets dispatched to the robot;
# popping one per dispatch tick is essentially free. Only
# generate a new chunk once the previous one has fully
# drained — this is the canonical "sense → think → act"
# loop. Refreshing while a chunk is still queued causes the
# new chunk to "telescope" past the old one (planned from an
# observation that's already 25+ steps stale by the time it
# starts dispatching).
queue = state.setdefault("action_queue", [])
if len(queue) > 0:
return None
observation = self.observation_provider()
if observation is None:
return None
# The action expert is conditioned on the SUBTASK generated by
# the high-level loop (``HighLevelSubtaskFwd`` runs earlier in
# the pipeline and writes ``current_subtask``). Matches the
# training-time ``low_level_execution`` recipe — ``user(${subtask})``.
# Falls back to the task string only on the very first frame,
# before the high-level loop has produced a subtask.
subtask = state.get("current_subtask") or state.get("task") or ""
ctx = [{"role": "user", "content": subtask}]
# ``add_generation_prompt=False`` to match the training-time
# prefix shape: at training the action expert sees the rendered
# user turn ending at ``<|im_end|>`` (no trailing
# ``<|im_start|>assistant\n``). Passing True here would append
# extra role-marker tokens the action expert never saw during
# training.
text_batch = _build_text_batch(self.policy, ctx, add_generation_prompt=False)
from lerobot.utils.constants import ( # noqa: PLC0415
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
)
observation = dict(observation)
observation[OBS_LANGUAGE_TOKENS] = text_batch["lang_tokens"]
observation[OBS_LANGUAGE_ATTENTION_MASK] = text_batch["lang_masks"]
try:
# ``predict_action_chunk`` returns the *full* chunk shape
# ``(batch, n_action_steps, action_dim)``. Enqueue every
# step so DispatchAction at ctrl_hz can drain them
# smoothly until the next refresh.
chunk = self.policy.predict_action_chunk(observation)
except Exception as exc: # noqa: BLE001
logger.warning(
"predict_action_chunk failed: %s",
exc,
exc_info=logger.isEnabledFor(logging.DEBUG),
)
push_log(
state,
f" [warn] predict_action_chunk failed: "
f"{type(exc).__name__}: {exc}",
)
return None
# ``chunk`` shape: ``(batch, n_action_steps, action_dim)``. Push
# each step as a ``(1, action_dim)`` tensor so the existing
# action executor's batch-squeeze logic works unchanged.
if chunk.ndim == 3:
chunk_iter = chunk[0] # ``(n_action_steps, action_dim)``
elif chunk.ndim == 2:
chunk_iter = chunk
else:
chunk_iter = chunk.unsqueeze(0)
for step in chunk_iter:
queue.append(step.unsqueeze(0))
state["last_chunk_size"] = int(chunk_iter.shape[0])
return None
@dataclass
class DispatchAction(InferenceStep):
"""Pop one action per tick and hand it to the robot.
In dry-run mode (``robot_executor=None``) the step still pops the
queue so it doesn't grow unbounded — the popped tensor is logged
instead of executed.
Wall-clock catch-up: the action queue represents an open-loop
trajectory at a fixed step rate (``trigger.hz`` ≈ ``ctrl_hz``).
When the main loop stalls — e.g. an LLM call for the high-level
subtask blocks for ~2 s on MPS — the dispatch trigger fires only
once over that whole interval. Naively popping a single entry per
fire makes the robot lag further and further behind the planned
timeline, and a 50-step chunk would take ~125 s to drain instead
of ~1.7 s. Track real elapsed time between dispatches and pop
``round(elapsed * hz)`` entries, sending the most recent one. The
skipped intermediate joint targets are stale anyway — the dynamixel
will smooth toward the latest goal position.
"""
robot_executor: Any = None
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=50.0))
_last_dispatch_t: float | None = field(default=None, init=False)
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
import time as _time # noqa: PLC0415
# ``/vlm`` mode pauses dispatch — the robot holds its last
# commanded position while the operator runs VQA.
if state.get("mode", "action") != "action":
self._last_dispatch_t = None
return None
queue = state.get("action_queue")
if not queue:
# Reset wall-clock anchor when the queue is empty so the
# next chunk doesn't see a huge fake "elapsed" window.
self._last_dispatch_t = None
return None
now = _time.monotonic()
hz = getattr(self.trigger, "hz", 30.0)
if self._last_dispatch_t is None or hz <= 0:
n_to_pop = 1
else:
elapsed = now - self._last_dispatch_t
# ``max(1, ...)`` so we always pop at least one when the
# trigger fires; ``min(len(queue), ...)`` so we don't run
# off the end of the chunk.
n_to_pop = max(1, min(len(queue), int(round(elapsed * hz))))
self._last_dispatch_t = now
# Drain ``n_to_pop`` stale entries, keep only the latest as the
# action actually sent. The intermediate joint targets would
# all be ~1030 ms apart in chunk time — the robot can't track
# them individually anyway when the host loop is slow.
latest = None
for _ in range(n_to_pop):
if not queue:
break
latest = queue.popleft() if hasattr(queue, "popleft") else queue.pop(0)
state["actions_dispatched"] = state.get("actions_dispatched", 0) + 1
if latest is not None and self.robot_executor is not None:
self.robot_executor(latest)
return None
# ---------------------------------------------------------------------------
# High-level (text) paths — all use policy.select_message
# ---------------------------------------------------------------------------
_LOC_TOKENIZER_CACHE: dict[str, Any] = {}
def _get_loc_tokenizer(tok_name: str, auto_tokenizer_cls: Any, register_loc_fn: Any) -> Any:
"""Return a loc-token-registered tokenizer, loading from disk only once.
``AutoTokenizer.from_pretrained`` + loc-token registration is expensive and
the result is immutable, so cache per ``tok_name``.
"""
tokenizer = _LOC_TOKENIZER_CACHE.get(tok_name)
if tokenizer is None:
tokenizer = register_loc_fn(auto_tokenizer_cls.from_pretrained(tok_name))
_LOC_TOKENIZER_CACHE[tok_name] = tokenizer
return tokenizer
def _build_text_batch(
policy: Any,
prompt_messages: list[dict[str, Any]],
*,
add_generation_prompt: bool = True,
) -> dict[str, Any]:
"""Tokenize chat messages into the batch ``select_message`` expects.
PI052's backbone (PaliGemma) ships no chat template, so we train on
a plain role-prefixed concatenation built by
``PI052TextTokenizerStep``. We reuse that exact formatter so the
inference prefix matches training; ``add_generation_prompt`` appends
the bare ``Assistant: `` header the LM head continues from.
"""
import torch # noqa: PLC0415
from transformers import AutoTokenizer # noqa: PLC0415
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: PLC0415
_flatten_say_tool_calls,
_format_messages,
_strip_blocks,
register_paligemma_loc_tokens,
)
tok_name = (
getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
)
# Register PaliGemma's <locDDDD> tokens so inference encoding /
# decoding sees them as single vocab ids — must match training.
# The tokenizer is read-only after registration, so cache it: rebuilding it
# from disk on every call dominated eval runtime (this runs twice per env
# per replan — subtask gen + action prompt).
tokenizer = _get_loc_tokenizer(tok_name, AutoTokenizer, register_paligemma_loc_tokens)
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages]
prompt, _spans = _format_messages(messages)
if add_generation_prompt:
prompt = prompt + "Assistant: "
encoded = tokenizer(prompt, return_tensors="pt")
ids = encoded["input_ids"]
attn = encoded.get("attention_mask")
if attn is None and tokenizer.pad_token_id is not None:
attn = ids != tokenizer.pad_token_id
if attn is not None and hasattr(attn, "dtype") and attn.dtype != torch.bool:
attn = attn.bool()
# Move tokens onto the policy's device — otherwise prefix embedding
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
# model), which the caller's broad except would swallow silently.
device = getattr(getattr(policy, "config", None), "device", None)
if device is not None:
try:
ids = ids.to(device)
if attn is not None and hasattr(attn, "to"):
attn = attn.to(device)
except Exception as exc: # noqa: BLE001
logger.debug("could not move pi052 lang tokens to %s: %s", device, exc)
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
new = dict(m)
new.pop("stream", None)
new.pop("target", None)
return new
@dataclass
class HighLevelSubtaskFwd(InferenceStep):
"""At ~1 Hz, ask the policy for the next subtask.
Mirrors the ``high_level_subtask`` recipe layout exactly:
user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
user: "Current subtask: ${subtask}" (if subtask present)
↓ generate ↓
assistant: <next subtask>
"""
policy: Any = None
observation_provider: Any = None
"""Same shape as ``LowLevelForward.observation_provider``. When
set, the resulting observation is merged into ``select_message``'s
batch so text generation runs against real video + state."""
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or not state.get("task"):
return None
# ``/vlm`` mode pauses subtask generation along with the rest of
# the action loop.
if state.get("mode", "action") != "action":
return None
# Gate to chunk boundaries: only generate a fresh subtask when
# the action queue is empty (i.e. right before LowLevelForward
# refreshes the chunk). ``select_message`` takes ~2 s on MPS,
# and running it every loop iteration starves DispatchAction
# at ctrl_hz=30 — the queue drains at ~0.4 actions/sec instead
# of 30/sec and the robot barely moves. Tying it to the same
# "queue empty" condition as the chunk refresh produces a
# clean sense → think → act cycle.
#
# Rearm the trigger when skipping so a low-hz schedule
# (e.g. ``--high_level_hz=0.2`` = once per 5 s) doesn't lose
# the slot: the trigger fires once on the timer but the brief
# queue-empty window almost never coincides, so without rearm
# HL would effectively never run.
queue = state.get("action_queue") or []
if len(queue) > 0:
if hasattr(self.trigger, "rearm"):
self.trigger.rearm()
return None
# Per-chunk-boundary throttle: at each "queue empty" moment we
# increment a counter; subtask gen only fires once the counter
# reaches ``subtask_chunks_per_gen``. Lets the operator run e.g.
# 5 action chunks per subtask-gen so the LM head doesn't churn
# every 1.7 s (a fresh subtask while the previous one is still
# being executed is wasted compute *and* causes the action
# expert's flow trajectory to be re-planned mid-grasp).
chunks_per_gen = max(1, int(state.get("subtask_chunks_per_gen", 1) or 1))
# Initialise so the first chunk boundary fires immediately
# (counter starts at chunks_per_gen, decrements per skip,
# generates and resets when it hits 0).
if "_hl_chunks_until_gen" not in state:
state["_hl_chunks_until_gen"] = 0
if state["_hl_chunks_until_gen"] > 0:
state["_hl_chunks_until_gen"] -= 1
if hasattr(self.trigger, "rearm"):
self.trigger.rearm()
return None
state["_hl_chunks_until_gen"] = chunks_per_gen - 1
ctx = _msgs_for_subtask(state)
observation = _maybe_observation(self.observation_provider)
# Default: greedy argmax, no min_new_tokens, no special-token
# suppression — matches training. Operator can override via
# ``--text_min_new_tokens=N --text_temperature=T --text_top_p=P``
# on the CLI; useful for under-trained checkpoints whose LM
# head still favours EOS at position 0 (pre-trained chat
# backbone's short-turn prior hasn't been fully overridden
# by the fine-tuning supervision yet).
msg = _generate_with_policy(
self.policy,
ctx,
observation=observation,
state=state,
label="subtask gen",
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
temperature=float(state.get("text_gen_temperature") or 0.0),
top_p=float(state.get("text_gen_top_p") or 1.0),
# Subtasks never legitimately contain PaliGemma ``<loc>``
# tokens — suppress them so a checkpoint whose LM head
# has drifted toward the pretrained loc-prior falls back
# to its (still-correct) text mass.
suppress_loc_tokens=True,
)
# Diagnostics: surface what the model is *actually* producing
# at chunk boundaries, even when the output gets rejected or
# repeats. Memorisation collapse looks like "same accepted
# subtask N times in a row" or "gibberish_count rising while
# current_subtask is stuck". The state panel renders these.
state["last_subtask_raw"] = msg or ""
# Persistent empty completion is its own failure mode (model
# immediately EOS-es from the chat-template generation
# prompt) — surface it once every N occurrences so the
# operator can distinguish "generation failing silently"
# from "generating fine but filter rejecting".
if not msg:
empties = state.get("subtask_empty_count", 0) + 1
state["subtask_empty_count"] = empties
if empties == 1 or empties % 5 == 0:
debug = getattr(self.policy, "_last_select_message_debug", "") or ""
if debug:
push_log(
state,
f" [info] subtask gen empty (×{empties}); {debug}",
)
else:
push_log(
state,
f" [info] subtask gen returned empty (×{empties}) — "
"no tokens generated (head EOS-ing before any "
"non-special token).",
)
if msg and _looks_like_gibberish(msg):
# Bump a counter so the operator can see the model is
# struggling without spamming the log every tick. A first
# rejection still logs once so the failure is visible.
count = state.get("subtask_gibberish_count", 0) + 1
state["subtask_gibberish_count"] = count
if count == 1 or count % 30 == 0:
push_log(
state,
f" [info] subtask gen rejected (gibberish ×{count}): {msg[:60]!r}",
)
return None
if msg:
prev_subtask = state.get("current_subtask")
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
if changed:
# Stash the just-completed subtask so ``MemoryUpdateFwd``
# can drop it into its prompt as ``Completed subtask:``
# — the recipe binds ``completed_subtask`` to
# ``nth_prev(style=subtask, offset=1)``, i.e. the subtask
# that was active *before* the change.
if prev_subtask:
state["prior_subtask"] = prev_subtask
# Subtask change is a downstream trigger.
state.setdefault("events_this_tick", []).append("subtask_change")
state["subtask_repeat_count"] = 0
else:
# Same accepted string regenerated — memorisation tell.
# Once this counter climbs past a few, you're seeing
# the model unable to move past the current subtask
# despite the chunk having drained (visual scene may
# have changed but the LM is replaying training
# tokens).
state["subtask_repeat_count"] = (
state.get("subtask_repeat_count", 0) + 1
)
# Silently skip empty completions — common when the model
# warms up or generates only EOS; logging it every tick at
# ctrl_hz is just noise.
return None
@dataclass
class MemoryUpdateFwd(InferenceStep):
"""On subtask boundary, refresh the compressed memory.
Mirrors the ``memory_update`` recipe layout exactly:
user: "${task}"
assistant: "Previous memory: ${prior_memory}" (if prior memory)
user: "Completed subtask: ${completed_subtask}" (if subtask)
↓ generate ↓
assistant: <new memory>
"""
policy: Any = None
observation_provider: Any = None
trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change"))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
# Don't consume the event — multiple steps may want to react.
if self.policy is None:
return None
ctx = _msgs_for_memory(state)
observation = _maybe_observation(self.observation_provider)
new_memory = _generate_with_policy(
self.policy,
ctx,
observation=observation,
state=state,
label="memory gen",
suppress_loc_tokens=True,
)
state["last_memory_raw"] = new_memory or ""
if new_memory and _looks_like_gibberish(new_memory):
count = state.get("memory_gibberish_count", 0) + 1
state["memory_gibberish_count"] = count
push_log(
state,
f" [info] memory gen rejected (gibberish ×{count}): {new_memory[:60]!r}",
)
return None
if new_memory:
set_if_changed(state, "current_memory", new_memory, label="memory")
return None
@dataclass
class UserInterjectionFwd(InferenceStep):
"""On stdin interjection, refresh the plan + emit a paired ``say``.
Mirrors the ``user_interjection_response`` recipe layout exactly:
user: "${task}"
assistant: "Previous plan:\\n${prior_plan}" (if prior plan)
user: "${interjection}" (the new utterance)
↓ generate ↓
assistant: <plan + <say>...</say>>
"""
policy: Any = None
observation_provider: Any = None
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection"))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or not take_event(state, "user_interjection"):
return None
ctx = _msgs_for_interjection(state)
observation = _maybe_observation(self.observation_provider)
out = _generate_with_policy(
self.policy,
ctx,
observation=observation,
state=state,
label="plan/say gen",
suppress_loc_tokens=True,
)
if not out:
# Don't log every empty completion — happens repeatedly on
# MPS during warm-up and floods the panel. The user can
# re-trigger by typing again.
return None
if _looks_like_gibberish(out):
count = state.get("plan_gibberish_count", 0) + 1
state["plan_gibberish_count"] = count
push_log(
state,
f" [info] plan/say gen rejected (gibberish ×{count}): {out[:60]!r}",
)
return None
# Heuristic split: model is trained to emit one assistant turn
# carrying both plan text AND a `say` tool call. Look for a
# "<say>...</say>" or "say(...)" marker; fall back to whole
# text → plan, no speech.
plan_text, speech_text = _split_plan_and_say(out)
if plan_text and _looks_like_gibberish(plan_text):
plan_text = ""
if plan_text:
set_if_changed(state, "current_plan", plan_text, label="plan")
if speech_text:
push_log(state, f" speech: {speech_text}")
state.setdefault("tool_calls_pending", []).append(
{
"type": "function",
"function": {"name": "say", "arguments": {"text": speech_text}},
}
)
state.setdefault("events_this_tick", []).append("tool_call_pending")
# Mark interjection consumed.
state["recent_interjection"] = None
return None
@dataclass
class AskVQAFwd(InferenceStep):
"""On stdin question, answer a frame-grounded VQA.
Mirrors the ``ask_vqa_*`` recipe layout exactly: a single user
turn carrying just the VQA question, plus the camera image block
in training (we drop the image at inference because the dataset's
image preprocessing doesn't match SmolVLM's vision tower input).
user: <question>
↓ generate ↓
assistant: <vqa answer>
"""
policy: Any = None
observation_provider: Any = None
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query"))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or not take_event(state, "user_vqa_query"):
return None
question = state.get("recent_vqa_query")
if not question:
return None
ctx = _msgs_for_vqa(question)
observation = _maybe_observation(self.observation_provider)
answer = _generate_with_policy(
self.policy,
ctx,
observation=observation,
state=state,
label="vqa gen",
)
# VQA answers are intentionally JSON-like during training, so
# ``_looks_like_gibberish`` would false-positive on them. Keep
# the answer as-is — the VQA panel line lets the user judge.
if answer:
push_log(state, f" vqa: {answer}")
state["recent_vqa_query"] = None
return None
# ---------------------------------------------------------------------------
# Tool dispatch
# ---------------------------------------------------------------------------
@dataclass
class DispatchToolCalls(InferenceStep):
"""Pop ``tool_calls_pending`` and execute them via :data:`TOOL_REGISTRY`."""
tools: dict[str, Any] = field(default_factory=dict)
trigger: Trigger = field(default_factory=lambda: EventTrigger("tool_call_pending"))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
take_event(state, "tool_call_pending")
pending = state.get("tool_calls_pending") or []
for call in pending:
try:
fn = (call or {}).get("function") or {}
name = fn.get("name")
args = fn.get("arguments") or {}
tool = self.tools.get(name)
if tool is None:
push_log(state, f" [warn] tool {name!r} not registered — skipping call")
continue
tool.call(args)
except Exception as exc: # noqa: BLE001
push_log(state, f" [error] tool dispatch failed: {exc}")
state["tool_calls_pending"] = []
return None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _looks_like_gibberish(text: str) -> bool:
"""Heuristically detect generation that's clearly off the rails.
Memorised models can collapse to dominant-mode outputs when the
prompt drifts even slightly from training distribution. Reject:
* empty / whitespace-only
* too few alphabetic characters (mostly punctuation)
* a single character repeated past the threshold
* starts with ``":"`` and contains no letters
* too few unique tokens — e.g. ``"the"``, ``"the the the"``,
``"Ass\\n::\\nthe"`` (the collapse seen on real-robot frames
where the model emits one or two memorised tokens repeatedly)
* chat-template fragment leakage (``Assistant:``, ``User:``,
``Ass\\n``)
Real subtasks look like ``"close the gripper to grasp the blue
cube"`` — multiple unique alphabetic tokens, no role-marker
fragments. Anything materially shorter than that is rejected.
"""
if not text or not text.strip():
return True
stripped = text.strip()
alpha = sum(1 for c in stripped if c.isalpha())
if alpha < max(3, len(stripped) // 8):
return True
if stripped.startswith('":') and stripped.count('"') > stripped.count(" "):
return True
# Single repeating char: e.g. ``""""""``.
if len(set(stripped)) <= 2 and len(stripped) > 4:
return True
# Chat-template fragment leakage — the model emits ``Ass``,
# ``Assistant:``, ``User:``, often with extra newlines/colons.
# Reject if the cleaned text is mostly role-marker shards.
cleaned = stripped.replace("\n", " ").replace(":", " ")
for marker in ("Assistant", "User", "Ass "):
if marker in cleaned and len(cleaned.split()) < 4:
return True
tokens = [t for t in cleaned.split() if any(c.isalpha() for c in t)]
unique_alpha = {t.lower() for t in tokens}
# Short degenerate output — model stuck on ``the`` or a couple of
# memorised single-token continuations.
if len(unique_alpha) < 3 and len(stripped) < 80:
return True
# Long repetition collapse — the LM head loops an n-gram for the
# whole generation budget ("the arm the arm … the the the the").
# Length-independent: many tokens but a tiny unique ratio. The
# earlier ``< 80`` check missed these because the looped string
# blows well past 80 chars.
if len(tokens) >= 8 and len(unique_alpha) <= max(3, len(tokens) // 10):
return True
return False
def _control_context_messages(
state: dict[str, Any],
*,
include_completed: bool = False,
extra_user: str | None = None,
) -> list[dict[str, Any]]:
"""Build a chat-template-ready prompt from current runtime state.
Mirrors what ``subtasks_vqa.yaml`` renders into ``${task}\nPlan:
${plan}\nMemory: ${memory}`` for the high-level branches.
"""
# Always emit ``Plan: `` / ``Memory: `` labels — even with empty
# values — to mirror the training-time recipe substitution.
task = state.get("task") or ""
plan = state.get("current_plan") or ""
memory = state.get("current_memory") or ""
parts = [task, f"Plan: {plan}", f"Memory: {memory}"]
if include_completed and state.get("current_subtask"):
parts.append(f"Completed subtask: {state['current_subtask']}")
head = "\n".join(parts)
msgs: list[dict[str, Any]] = [{"role": "user", "content": head}]
if extra_user:
msgs.append({"role": "user", "content": extra_user})
return msgs
# ---------------------------------------------------------------------------
# Per-recipe prompt builders. Each one mirrors a single sub-recipe's
# message layout in ``subtasks_vqa.yaml`` so the chat-templated
# prompt at inference matches what the model saw during training.
# Generic ``_control_context_messages`` is kept around as a fallback
# for ad-hoc callers but the four high-level steps now use these.
# ---------------------------------------------------------------------------
def _hirobot_user_head(state: dict[str, Any]) -> str:
"""Build the ``task\\nPlan: …\\nMemory: …`` user content string.
Mirrors what the recipe renders at training time, where
``language_render._substitute`` substitutes empty strings for
missing ``${plan}`` / ``${memory}`` bindings — i.e. the
``Plan: `` / ``Memory: `` prefix labels are *always* in the
user turn, even when their values aren't set yet. Skipping them
here (the previous behaviour) produced a different prompt shape
on early frames before plan / memory are populated and on
samples where the dataset has no plan / memory annotation.
"""
task = state.get("task") or ""
plan = state.get("current_plan") or ""
memory = state.get("current_memory") or ""
return f"{task}\nPlan: {plan}\nMemory: {memory}"
def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]:
"""``high_level_subtask`` recipe layout — predict the subtask from the
task. The v-current recipe's user turn is just ``${task}`` (plan and
memory are not trained), so the inference prompt is the bare task —
no ``Plan: `` / ``Memory: `` lines.
"""
return [{"role": "user", "content": state.get("task") or ""}]
def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]:
"""Memory-update prompt — mirrors ``memory_update`` recipe layout.
Recipe layout (``subtask_mem.yaml``):
user: "${task}"
assistant: "Previous memory: ${prior_memory}" (if_present prior)
user: "Completed subtask: ${completed}" (if_present completed)
assistant: → predicts new memory
Fired by ``MemoryUpdateFwd`` on a ``subtask_change`` event:
``state['current_memory']`` is the memory the policy last emitted
(= the ``prior_memory`` binding at training), and
``state['prior_subtask']`` is the subtask that just got replaced
(= the ``completed_subtask`` binding at training).
"""
msgs: list[dict[str, Any]] = [
{"role": "user", "content": state.get("task") or ""},
]
prior_memory = state.get("current_memory")
if prior_memory:
msgs.append(
{"role": "assistant", "content": f"Previous memory: {prior_memory}"}
)
completed_subtask = state.get("prior_subtask")
if completed_subtask:
msgs.append(
{"role": "user", "content": f"Completed subtask: {completed_subtask}"}
)
return msgs
def _msgs_for_interjection(state: dict[str, Any]) -> list[dict[str, Any]]:
"""``user_interjection_response`` recipe layout."""
msgs: list[dict[str, Any]] = [
{"role": "user", "content": state.get("task") or ""}
]
if state.get("current_plan"):
msgs.append(
{"role": "assistant", "content": f"Previous plan:\n{state['current_plan']}"}
)
interjection = state.get("recent_interjection")
if interjection:
msgs.append({"role": "user", "content": interjection})
return msgs
def _msgs_for_plan(state: dict[str, Any]) -> list[dict[str, Any]]:
"""``plan_generation`` recipe layout — bare task → plan.
The assistant turn is the generation target, so we only render
the user turn at inference; the runtime appends the predicted
plan after sampling.
"""
return [{"role": "user", "content": state.get("task") or ""}]
def _msgs_for_vqa(question: str) -> list[dict[str, Any]]:
"""``ask_vqa_*`` recipe layout (text-only at inference)."""
return [{"role": "user", "content": question}]
def _maybe_observation(provider: Any) -> dict | None:
"""Pull one observation from ``provider`` if it's set, else ``None``.
Errors from the provider are logged at debug level and swallowed —
text generation still runs (in text-only mode) so a flaky frame
source doesn't kill the REPL.
"""
if provider is None:
return None
try:
return provider()
except Exception as exc: # noqa: BLE001
logger.debug("observation_provider raised %s — falling back to text-only", exc)
return None
def _generate_with_policy(
policy: Any,
messages: list[dict[str, Any]],
*,
observation: dict | None = None,
state: dict[str, Any] | None = None,
label: str = "select_message",
min_new_tokens: int = 0,
temperature: float = 0.0,
top_p: float = 1.0,
suppress_loc_tokens: bool = False,
) -> str:
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
When ``observation`` carries ``observation.images.*`` and
``observation.state``, those are merged into the batch so
``select_message`` runs the same VLM prefix the policy was trained
on. Without an observation the runtime falls back to a text-only
prompt — the text head still runs, but generations may drift from
the training distribution.
Failures are surfaced both to the module logger (``warning``) and,
when ``state`` is given, to the runtime's user-visible log via
:func:`push_log`, so the REPL no longer "looks dead" when
something goes wrong inside generation.
"""
if not hasattr(policy, "select_message"):
if state is not None:
push_log(state, f" [warn] policy has no select_message — skipping {label}")
return ""
text_batch = _build_text_batch(policy, messages)
try:
from lerobot.utils.constants import ( # noqa: PLC0415
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
)
batch: dict[str, Any] = {
OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"],
OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"],
}
if observation:
for k, v in observation.items():
if isinstance(k, str) and k.startswith("observation.") and k not in batch:
batch[k] = v
kwargs: dict[str, Any] = {
"tokenizer": text_batch["tokenizer"],
"min_new_tokens": min_new_tokens,
"temperature": temperature,
"top_p": top_p,
}
kwargs["suppress_loc_tokens"] = suppress_loc_tokens
return policy.select_message(batch, **kwargs)
except Exception as exc: # noqa: BLE001
logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG))
if state is not None:
push_log(state, f" [warn] {label} failed: {type(exc).__name__}: {exc}")
return ""
_SAY_RE = re.compile(r"<\s*say\s*>(.*?)<\s*/\s*say\s*>", re.IGNORECASE | re.DOTALL)
def _split_plan_and_say(text: str) -> tuple[str, str]:
"""Pull a ``<say>...</say>`` snippet out of ``text``; remainder is plan.
The training-time tool-call serializer wraps ``say(text="")`` in a
deterministic textual marker so prefix-LM-style training learns to
emit it. The runtime parses it back here. If no marker is present,
the entire text is treated as plan with no speech.
"""
if not text:
return "", ""
match = _SAY_RE.search(text)
if not match:
return text.strip(), ""
speech = match.group(1).strip().strip('"').strip("'")
plan = (text[: match.start()] + text[match.end() :]).strip()
return plan, speech
@@ -0,0 +1,134 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trigger primitives for PI052's multi-rate inference runtime.
Mirrors the plan's Section "Runtime orchestration": each
``InferenceStep`` is gated by a :class:`Trigger` that decides per tick
whether the step fires. Two trigger flavours cover all the cadences
the canonical recipe needs:
* :class:`HzTrigger` for periodic beats (action chunks at ~3-5 Hz,
high-level subtask generation at ~1 Hz, action dispatch at ~50 Hz)
* :class:`EventTrigger` for one-shot reactions (subtask boundary →
memory update; user interjection → plan refresh; user VQA query →
vqa answer; pending tool call → dispatcher)
Triggers are stateless except for ``HzTrigger``'s last-fire timestamp.
The runtime stores the :class:`Tick` clock as ``state["_tick"]`` so
every step shares a single time source.
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import Any, Protocol
@dataclass
class Tick:
"""Single tick from :class:`TickClock`. Carries time references the
runtime steps consume to gate themselves."""
index: int
"""Monotonic counter — increments by one per tick."""
monotonic_seconds: float
"""``time.monotonic()`` at the start of this tick."""
@dataclass
class TickClock:
"""Drives the runtime loop at up to ``max_rate_hz``.
Sleeps just enough between :meth:`advance` calls to enforce the
rate. With ``max_rate_hz=50`` the loop wakes ~every 20ms; the
higher-level ``HzTrigger`` slices that timeline into sub-cadences.
"""
max_rate_hz: float = 50.0
_index: int = field(default=0, init=False)
_last_seconds: float | None = field(default=None, init=False)
def advance(self) -> Tick:
period = 1.0 / max(self.max_rate_hz, 0.1)
now = time.monotonic()
if self._last_seconds is not None:
sleep_for = (self._last_seconds + period) - now
if sleep_for > 0:
time.sleep(sleep_for)
now = time.monotonic()
self._last_seconds = now
self._index += 1
return Tick(index=self._index, monotonic_seconds=now)
class Trigger(Protocol):
"""Decide whether the next ``InferenceStep`` should fire."""
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool: ...
@dataclass
class HzTrigger:
"""Fire at most ``hz`` times per second.
A step that gates further (e.g. ``HighLevelSubtaskFwd`` skipping
when the action queue is non-empty) and wants the trigger to
retry next tick instead of waiting a full period can call
:meth:`rearm` from inside ``run``. Without this, a low-hz trigger
(e.g. ``hz=0.2`` = once per 5 s) almost never coincides with the
brief queue-empty window and the step never fires at all.
"""
hz: float
_last_seconds: float | None = field(default=None, init=False)
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
period = 1.0 / max(self.hz, 1e-6)
if self._last_seconds is None or (tick.monotonic_seconds - self._last_seconds) >= period:
self._last_seconds = tick.monotonic_seconds
return True
return False
def rearm(self) -> None:
"""Mark the trigger as not having fired, so the next tick re-evaluates.
Used by a step that decided to skip after ``should_fire`` already
committed the firing — keeps the cadence honest without losing
the slot.
"""
self._last_seconds = None
@dataclass
class EventTrigger:
"""Fire when ``event_name`` is in ``state["events_this_tick"]``.
The runtime fills ``events_this_tick`` once per tick from:
* stdin / network input (``user_interjection``, ``user_vqa_query``,
``stop``)
* internal state transitions (``subtask_change``,
``tool_call_pending``)
The list is consumed (cleared at the end of the tick) so events
fire at most once.
"""
event_name: str
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
events: list[str] = state.get("events_this_tick") or []
return self.event_name in events
+127
View File
@@ -0,0 +1,127 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rich-based REPL layout for the PI052 runtime.
Two-zone terminal layout:
[chat scrollback — user messages / robot responses, scrolls naturally]
┌── State ──────────────────────────────────────────┐
│ task please clean up the kitchen │
│ subtask grasp the handle of the sponge │
│ plan 1. grasp sponge 2. wipe 3. tidy │
│ memory sponge picked up; counter still dirty │
└───────────────────────────────────────────────────┘
> _
The state panel re-renders on every state change. Chat lines are
``console.print``'d above the live region so they accumulate naturally
in scrollback. Implemented with :class:`rich.live.Live` plus
:func:`rich.console.Console.input` for the prompt — when an input is
pending, ``rich.Live`` auto-suspends so the input doesn't fight the
panel for cursor position.
"""
from __future__ import annotations
from typing import Any
try: # rich is optional; only required for the interactive REPL.
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
_HAS_RICH = True
except ImportError: # pragma: no cover
_HAS_RICH = False
Console = Any # type: ignore[assignment]
Panel = Any # type: ignore[assignment]
Table = Any # type: ignore[assignment]
Text = Any # type: ignore[assignment]
_STATE_KEYS = (
("task", "task"),
("current_subtask", "subtask"),
("current_plan", "plan"),
("current_memory", "memory"),
)
def make_state_panel(state: dict[str, Any]) -> Any:
"""Render the persistent state panel for the live region.
Returns a :class:`rich.panel.Panel`. Caller passes it to
``Live.update(panel)`` whenever the state changes.
"""
if not _HAS_RICH:
raise RuntimeError(
"rich is required for the interactive REPL. "
"`pip install rich` (it's a transitive dep of lerobot)."
)
table = Table.grid(padding=(0, 2), expand=True)
table.add_column(justify="right", style="dim", no_wrap=True, width=10)
table.add_column(justify="left")
for key, label in _STATE_KEYS:
value = state.get(key)
if value is None:
rendered = Text("(not set)", style="dim italic")
else:
rendered = Text(str(value), style="bold")
table.add_row(label, rendered)
queue = state.get("action_queue")
queue_len = len(queue) if hasattr(queue, "__len__") else 0
pending = state.get("tool_calls_pending") or []
footer = Text.assemble(
("queued actions: ", "dim"),
(str(queue_len), "bold cyan"),
(" pending tool calls: ", "dim"),
(str(len(pending)), "bold magenta"),
)
table.add_row("", footer)
run_mode = state.get("mode", "action")
mode_tag = (
"[green]action[/]" if run_mode == "action" else "[yellow]paused[/]"
)
return Panel(
table,
title=f"[bold]PI052 state[/] · mode: {mode_tag}",
border_style="cyan",
)
def print_user_line(console: Any, line: str) -> None:
"""Append a user-typed line to the chat scrollback."""
if not _HAS_RICH:
print(f"you: {line}", flush=True)
return
console.print(f"[bold cyan]you:[/] {line}")
def print_robot_lines(console: Any, lines: list[str]) -> None:
"""Append robot/runtime log lines to the chat scrollback."""
if not _HAS_RICH:
for line in lines:
print(f"robot: {line.lstrip()}", flush=True)
return
for line in lines:
# The runtime uses leading whitespace + "label: text"; render
# the label in green and the value in default for readability.
stripped = line.lstrip()
if ":" in stripped:
label, _, value = stripped.partition(":")
console.print(f"[bold green]robot[/] [dim]({label.strip()})[/] {value.strip()}")
else:
console.print(f"[bold green]robot:[/] {stripped}")
+423
View File
@@ -0,0 +1,423 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interactive VQA for the PI052 runtime.
In ``/vlm`` mode a typed line is treated as a VQA question. This module
runs the full interactive flow:
1. pull the current observation and list available cameras,
2. ask the operator which camera to ground the question on,
3. generate the answer with the VLM conditioned on that one camera,
4. parse the JSON answer; if it carries a bounding box (``bbox``) or a
point (``keypoint``), draw the overlay on the camera frame, save a
PNG to ``./vqa_overlays/`` and auto-open it.
VQA answer schemas mirror the annotation pipeline's ``VQA_ANSWER_SHAPES``
(see ``lerobot.annotations.steerable_pipeline.validator``):
* ``bbox`` — ``{"detections": [{"label", "bbox_format": "xyxy",
"bbox": [x1, y1, x2, y2]}, ...]}``
* ``keypoint`` — ``{"label", "point_format": "xy", "point": [x, y]}``
* ``count`` / ``attribute`` / ``spatial`` — text-only, no overlay.
"""
from __future__ import annotations
import json
import logging
import os
import re
import subprocess
import sys
import time
import webbrowser
from pathlib import Path
from typing import Any
from .runtime_state import push_log
logger = logging.getLogger(__name__)
_IMAGE_PREFIX = "observation.images."
# PaliGemma detection / pointing vocabulary. PI052 trains spatial VQA
# answers in this native ``<locNNNN>`` format (index in [0, 1023],
# normalized to the image axis) instead of pixel-coordinate JSON, so the
# answer string the runtime parses can be e.g.
# ``<loc0512><loc0301> blue cube`` (point) or
# ``<loc0100><loc0080><loc0400><loc0360> blue cube`` (box).
_LOC_RE = re.compile(r"<loc(\d{1,4})>")
# Iteration order for shape matching — most specific keys first so an
# answer is classified deterministically.
_SHAPE_ORDER = ("bbox", "keypoint", "count", "attribute", "spatial")
_BBOX_COLOR = (255, 64, 64)
_POINT_COLOR = (64, 220, 64)
# ---------------------------------------------------------------------------
# Camera selection
# ---------------------------------------------------------------------------
def available_cameras(observation: dict | None) -> list[str]:
"""Return the sorted ``observation.images.*`` keys present in ``observation``."""
if not observation:
return []
return sorted(k for k in observation if isinstance(k, str) and k.startswith(_IMAGE_PREFIX))
def camera_short_name(camera_key: str) -> str:
"""Strip the ``observation.images.`` prefix for display."""
return camera_key[len(_IMAGE_PREFIX) :] if camera_key.startswith(_IMAGE_PREFIX) else camera_key
def prompt_camera_choice(
cameras: list[str],
*,
input_fn: Any = input,
print_fn: Any = print,
) -> str | None:
"""Ask the operator which camera frame to draw a VQA overlay on.
Accepts either the menu number or the (short or full) camera name.
A single-camera setup auto-selects without prompting. Returns the
chosen ``observation.images.*`` key, or ``None`` if the operator
cancels / gives an invalid answer.
"""
if not cameras:
return None
if len(cameras) == 1:
return cameras[0]
print_fn("Draw the result on which camera?")
for i, cam in enumerate(cameras, 1):
print_fn(f" [{i}] {camera_short_name(cam)}")
try:
raw = str(input_fn("camera> ")).strip()
except (EOFError, KeyboardInterrupt):
return None
if not raw:
return cameras[0]
if raw.isdigit():
idx = int(raw) - 1
return cameras[idx] if 0 <= idx < len(cameras) else None
for cam in cameras:
if raw == cam or raw == camera_short_name(cam):
return cam
return None
# ---------------------------------------------------------------------------
# Answer parsing
# ---------------------------------------------------------------------------
def _loc_to_norm(idx: int) -> float:
"""PaliGemma ``<locNNNN>`` index → normalized [0, 1] axis coordinate."""
return max(0.0, min(1023.0, float(idx))) / 1023.0
def parse_loc_answer(answer: str) -> dict | None:
"""Parse a PaliGemma ``<loc>``-format spatial VQA answer.
PI052 trains spatial answers in PaliGemma's native detection
vocabulary, label-first: a point is ``<label> <locY><locX>``, a box
is ``<label> <locY0><locX0><locY1><locX1>``, and multiple boxes are
joined by `` ; `` (e.g. ``cube <loc..><loc..><loc..><loc..> ; box
<loc..><loc..><loc..><loc..>``). Loc-first formats are also accepted
— this parser strips loc tokens and treats the remainder as the
label, so order is irrelevant. Coordinates come back *normalized*
([0, 1]); the overlay denormalizes them against the chosen camera
frame's pixel size.
Returns ``{"kind", "payload", "normalized": True}`` on success
(``payload`` mirrors the JSON shapes so the overlay code is shared),
or ``None`` when the answer carries no ``<loc>`` tokens.
"""
if not answer or "<loc" not in answer:
return None
segments = [seg for seg in answer.split(";") if "<loc" in seg]
points: list[tuple[float, float, str]] = []
boxes: list[tuple[float, float, float, float, str]] = []
for seg in segments:
locs = [int(m) for m in _LOC_RE.findall(seg)]
label = _LOC_RE.sub("", seg).strip()
if len(locs) == 2:
y, x = (_loc_to_norm(v) for v in locs[:2])
points.append((x, y, label))
elif len(locs) >= 4:
y1, x1, y2, x2 = (_loc_to_norm(v) for v in locs[:4])
boxes.append((x1, y1, x2, y2, label))
if boxes:
detections = [
{"label": lbl, "bbox_format": "xyxy", "bbox": [x1, y1, x2, y2]}
for (x1, y1, x2, y2, lbl) in boxes
]
return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
if len(points) == 1:
x, y, lbl = points[0]
return {
"kind": "keypoint",
"payload": {"label": lbl, "point_format": "xy", "point": [x, y]},
"normalized": True,
}
if points: # several bare points → treat as detections-as-points
detections = [
{"label": lbl, "bbox_format": "xyxy", "bbox": [x, y, x, y]} for (x, y, lbl) in points
]
return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
return None
def parse_vqa_answer(answer: str) -> dict | None:
"""Parse a VQA answer string into ``{"kind", "payload"}``.
``kind`` is one of the ``VQA_ANSWER_SHAPES`` names (``bbox``,
``keypoint``, ``count``, ``attribute``, ``spatial``) or ``"unknown"``
when the JSON doesn't match any known shape. PaliGemma ``<loc>``
spatial answers are detected first (PI052 trains them in that native
format). Returns ``None`` when the answer is neither ``<loc>`` text
nor a parseable JSON object.
"""
if not answer or not answer.strip():
return None
loc_parsed = parse_loc_answer(answer)
if loc_parsed is not None:
return loc_parsed
try:
payload = json.loads(answer)
except (ValueError, TypeError):
return None
if not isinstance(payload, dict):
return None
try:
from lerobot.annotations.steerable_pipeline.validator import ( # noqa: PLC0415
VQA_ANSWER_SHAPES,
)
shapes = VQA_ANSWER_SHAPES
except ImportError: # pragma: no cover - annotation extra not installed
shapes = {
"bbox": {"detections"},
"keypoint": {"label", "point_format", "point"},
"count": {"label", "count"},
"attribute": {"label", "attribute", "value"},
"spatial": {"subject", "relation", "object"},
}
keys = set(payload)
for kind in _SHAPE_ORDER:
required = shapes.get(kind)
if required and required <= keys:
return {"kind": kind, "payload": payload}
return {"kind": "unknown", "payload": payload}
def answer_has_overlay(parsed: dict | None) -> bool:
"""True iff ``parsed`` carries drawable spatial coordinates."""
return bool(parsed) and parsed.get("kind") in ("bbox", "keypoint")
# ---------------------------------------------------------------------------
# Overlay drawing
# ---------------------------------------------------------------------------
def observation_image_to_pil(image_tensor: Any) -> Any:
"""Convert an ``observation.images.*`` tensor to a PIL RGB image.
The runtime observation stores images as ``(1, C, H, W)`` (or
``(C, H, W)``) float tensors in ``[0, 1]``. Reuses
``image_array_to_pil_image`` which handles the CHW→HWC transpose and
the float→uint8 scaling.
"""
from lerobot.datasets.image_writer import image_array_to_pil_image # noqa: PLC0415
arr = image_tensor
if hasattr(arr, "detach"):
arr = arr.detach().cpu()
if hasattr(arr, "numpy"):
arr = arr.numpy()
while arr.ndim > 3: # drop leading batch dim(s)
arr = arr[0]
return image_array_to_pil_image(arr).convert("RGB")
def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
"""Draw ``bbox`` / ``keypoint`` answers onto a copy of ``image``.
Non-spatial answers (``count`` / ``attribute`` / ``spatial`` /
``unknown``) are returned as an unmodified copy. When ``parsed`` has
``normalized=True`` (PaliGemma ``<loc>`` answers) the [0, 1]
coordinates are scaled to the image's pixel size.
"""
from PIL import ImageDraw # noqa: PLC0415
img = image.convert("RGB").copy()
kind = parsed.get("kind")
payload = parsed.get("payload") or {}
draw = ImageDraw.Draw(img)
w, h = img.size
sx, sy = (w, h) if parsed.get("normalized") else (1, 1)
if kind == "bbox":
for det in payload.get("detections") or []:
if not isinstance(det, dict):
continue
box = det.get("bbox")
if not (isinstance(box, list | tuple) and len(box) == 4):
continue
try:
x1, y1, x2, y2 = (float(v) for v in box)
except (TypeError, ValueError):
continue
x1, x2 = x1 * sx, x2 * sx
y1, y2 = y1 * sy, y2 * sy
draw.rectangle([x1, y1, x2, y2], outline=_BBOX_COLOR, width=3)
label = str(det.get("label", "")).strip()
if label:
draw.text((x1 + 3, max(0.0, y1 - 12)), label, fill=_BBOX_COLOR)
elif kind == "keypoint":
point = payload.get("point")
if isinstance(point, list | tuple) and len(point) == 2:
try:
x, y = float(point[0]) * sx, float(point[1]) * sy
except (TypeError, ValueError):
return img
r = 6
draw.ellipse([x - r, y - r, x + r, y + r], outline=_POINT_COLOR, width=3)
draw.line([x - 2 * r, y, x + 2 * r, y], fill=_POINT_COLOR, width=2)
draw.line([x, y - 2 * r, x, y + 2 * r], fill=_POINT_COLOR, width=2)
label = str(payload.get("label", "")).strip()
if label:
draw.text((x + r + 3, y - r), label, fill=_POINT_COLOR)
return img
def _open_file(path: Path) -> None:
"""Best-effort open ``path`` in the OS default viewer."""
try:
if sys.platform == "darwin":
subprocess.run(["open", str(path)], check=False)
elif sys.platform.startswith("linux"):
subprocess.run(["xdg-open", str(path)], check=False)
elif os.name == "nt":
os.startfile(str(path)) # type: ignore[attr-defined] # noqa: S606
else: # pragma: no cover - exotic platform
webbrowser.open(path.resolve().as_uri())
except Exception as exc: # noqa: BLE001
logger.debug("could not auto-open %s: %s", path, exc)
def save_and_open_overlay(image: Any, out_dir: str | Path = "./vqa_overlays") -> Path:
"""Save ``image`` as a timestamped PNG under ``out_dir`` and auto-open it."""
out = Path(out_dir)
out.mkdir(parents=True, exist_ok=True)
path = out / f"vqa_{int(time.time() * 1000)}.png"
image.save(path)
_open_file(path)
return path
# ---------------------------------------------------------------------------
# Orchestrator
# ---------------------------------------------------------------------------
def handle_vqa_query(
*,
policy: Any,
observation_provider: Any,
question: str,
state: dict[str, Any],
input_fn: Any = input,
print_fn: Any = print,
) -> None:
"""Run one interactive VQA question end to end.
Called synchronously from the input layer while the runtime is in
``/question`` mode (the action loop is gated off, so the policy is
not in concurrent use). Progress is reported via both
:func:`push_log` (REPL panel scrollback) and ``print_fn`` (direct
stdout) — in autonomous question mode the panel redraw is suspended,
so the direct print is what the operator actually sees.
"""
from .steps import _generate_with_policy, _msgs_for_vqa # noqa: PLC0415
def report(line: str) -> None:
"""Surface a line both to the panel scrollback and to stdout."""
push_log(state, line)
try:
print_fn(line)
except Exception: # noqa: BLE001
pass
if policy is None or not hasattr(policy, "select_message"):
report(" [warn] vqa: policy has no select_message — skipping")
return
observation: dict | None = None
if observation_provider is not None:
try:
observation = observation_provider()
except Exception as exc: # noqa: BLE001
logger.debug("observation_provider raised %s", exc)
# Feed the FULL observation (every camera + state) to the VLM. The
# ``ask_vqa_*`` recipes look single-camera, but the image *block* is
# stripped before tokenization — the actual frames reach the model
# via PI052's ``OBS_IMAGES_*`` channels, and ``embed_prefix``
# consumes *all* ``config.image_features`` regardless of which
# camera the sub-recipe was tagged for. So the model always sees
# every camera; the operator never has to name one to ask.
answer = _generate_with_policy(
policy,
_msgs_for_vqa(question),
observation=observation,
state=state,
label="vqa gen",
)
if not answer:
report(" [info] vqa gen returned empty")
return
report(f" vqa: {answer}")
parsed = parse_vqa_answer(answer)
if not answer_has_overlay(parsed):
if parsed is None:
report(" [info] vqa answer is not JSON — no overlay")
return
# The answer carries a bounding box / point. Its pixel coordinates
# are camera-specific and the text answer doesn't say which camera,
# so ask the operator *now* — only when there is actually something
# to draw — which camera frame to render the overlay on.
cameras = available_cameras(observation)
if observation is None or not cameras:
report(" [info] no camera image — cannot draw overlay")
return
chosen = prompt_camera_choice(cameras, input_fn=input_fn, print_fn=print_fn)
if chosen is None:
report(" [info] overlay skipped — no camera selected")
return
try:
pil = observation_image_to_pil(observation[chosen])
overlay = draw_vqa_overlay(pil, parsed)
path = save_and_open_overlay(overlay)
report(f" vqa overlay ({camera_short_name(chosen)}) saved: {path}")
except Exception as exc: # noqa: BLE001
logger.warning("vqa overlay failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG))
report(f" [warn] vqa overlay failed: {type(exc).__name__}: {exc}")
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,198 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 pre/post-processor factory.
When ``config.recipe_path`` is set, the pre-processor pipeline becomes:
rename observations
add batch dim
relative-action prep (inherited from π0.5)
NormalizerProcessorStep
RenderMessagesStep recipe messages, target_message_indices,
message_streams (PR 1 of the steerable
stack)
PI052TextTokenizerStep messages input_ids + label mask +
predict_actions
DeviceProcessorStep
When ``recipe_path`` is ``None`` we delegate to the plain π0.5 pipeline
so unannotated datasets keep working.
Post-processor is unchanged from π0.5.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from lerobot.configs.recipe import TrainingRecipe
from lerobot.processor import (
AbsoluteActionsProcessorStep,
ActionTokenizerProcessorStep,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RelativeActionsProcessorStep,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
policy_action_to_transition,
transition_to_policy_action,
)
# RenderMessagesStep is intentionally not re-exported from
# ``lerobot.processor`` because it pulls in optional language-stack deps;
# import it directly.
from lerobot.processor.render_messages_processor import RenderMessagesStep
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from ..pi05.processor_pi05 import make_pi05_pre_post_processors
from .configuration_pi052 import PI052Config
from .text_processor_pi052 import PI052TextTokenizerStep
def make_pi052_pre_post_processors(
config: PI052Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
dataset_repo_id: str | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Build PI0.5-v2's pre/post-processor pipelines.
Falls through to π0.5's stock pipeline when ``recipe_path`` is unset.
"""
if not config.recipe_path:
return make_pi05_pre_post_processors(config, dataset_stats=dataset_stats)
recipe = _load_recipe(config.recipe_path)
relative_step = RelativeActionsProcessorStep(
enabled=config.use_relative_actions,
exclude_joints=getattr(config, "relative_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
relative_step,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
RenderMessagesStep(recipe=recipe),
PI052TextTokenizerStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
plan_dropout_prob=getattr(config, "plan_dropout_prob", 0.0),
memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
),
]
# FAST tokenizer for discrete-action CE supervision (paper §III.C).
# Only inserted when explicitly enabled — keeps the post-training-
# style recipe (flow + text) as the default. When on, the step
# writes ACTION_TOKENS / ACTION_TOKEN_MASK into
# ``COMPLEMENTARY_DATA`` and the modeling forward picks them up.
if getattr(config, "enable_fast_action_loss", False):
# Per Pertsch et al. 2025 (FAST [64], π0.5 §III.C): fit the
# tokenizer on this dataset's action distribution rather than
# using the universal codebook off the shelf. We do this once
# and cache to disk, keyed on (dataset, base, n_samples).
action_tokenizer_path = config.action_tokenizer_name
if (
getattr(config, "auto_fit_fast_tokenizer", False)
and dataset_repo_id is not None
):
from .fit_fast_tokenizer import fit_fast_tokenizer # noqa: PLC0415
cache_dir = Path(config.fast_tokenizer_cache_dir).expanduser()
try:
action_tokenizer_path = fit_fast_tokenizer(
dataset_repo_id=dataset_repo_id,
cache_dir=cache_dir,
base_tokenizer_name=config.action_tokenizer_name,
n_samples=config.fast_tokenizer_fit_samples,
chunk_size=config.chunk_size,
)
except Exception as exc: # noqa: BLE001
import logging # noqa: PLC0415
logging.getLogger(__name__).warning(
"FAST tokenizer fit failed (%s) — falling back to "
"the universal base tokenizer %r. Train will still "
"work but compression will be suboptimal.",
exc, config.action_tokenizer_name,
)
input_steps.append(
ActionTokenizerProcessorStep(
action_tokenizer_name=action_tokenizer_path,
max_action_tokens=config.max_action_tokens,
fast_skip_tokens=config.fast_skip_tokens,
paligemma_tokenizer_name="google/paligemma-3b-pt-224",
)
)
input_steps.append(DeviceProcessorStep(device=config.device))
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
AbsoluteActionsProcessorStep(
enabled=config.use_relative_actions,
relative_step=relative_step,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
def _load_recipe(path_str: str) -> TrainingRecipe:
"""Resolve ``path_str`` to a ``TrainingRecipe``.
Accepts an absolute path or a path relative to
``src/lerobot/configs/``.
"""
p = Path(path_str)
if not p.is_absolute() and not p.exists():
from lerobot.configs import recipe as _recipe_module # noqa: PLC0415
configs_dir = Path(_recipe_module.__file__).resolve().parent
candidate = configs_dir / path_str
if candidate.exists():
p = candidate
return TrainingRecipe.from_yaml(p)
@@ -0,0 +1,641 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 text-tokenisation step.
PaliGemma is *not* chat-pretrained, so we can't lean on
``tokenizer.apply_chat_template``. Instead we concatenate the rendered
messages as plain text with simple ``User: ... Assistant: ...`` role
delimiters matching the prompt format π0.5 uses in the paper
(``Task: ... State: ... Action: ...``).
Outputs:
* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` the
concatenated prompt tokenised by the PaliGemma tokenizer (the same
one ``processor_pi05`` already uses).
* ``text_labels`` same shape as token ids, ``-100`` everywhere except
positions belonging to messages whose index is in
``target_message_indices``. ``modeling_pi052`` runs cross-entropy on
those positions via the PaliGemma ``lm_head``.
* ``predict_actions`` bool tensor, ``True`` iff any of the rendered
target messages has ``message_streams[i] == "low_level"``.
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from torch import Tensor
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
logger = logging.getLogger(__name__)
def discretize_state_str(state_row: Any) -> str:
"""Discretize a single normalized state vector into 256 bins, space-joined.
Mirrors pi05's ``Pi05PrepareStateTokenizerProcessorStep`` (same bins /
convention) so pi052's low-level action prompt carries proprioception in
the exact format pi05 was trained on. Expects state already normalized by
the upstream ``NormalizerProcessorStep``.
"""
arr = state_row.detach().cpu().numpy() if hasattr(state_row, "detach") else np.asarray(state_row)
disc = np.digitize(arr, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
return " ".join(str(int(x)) for x in disc.reshape(-1).tolist())
def _state_row_at(state_all: Any, pos: int) -> Any:
"""Select the per-sample state row from a (possibly batched) state tensor."""
if state_all is None:
return None
if hasattr(state_all, "ndim") and state_all.ndim >= 2:
return state_all[pos]
return state_all
def _content_to_text(content: Any) -> str:
"""Collapse a message's ``content`` (string or multimodal blocks) to text."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [
b["text"]
for b in content
if isinstance(b, dict) and b.get("type") == "text" and isinstance(b.get("text"), str)
]
return "\n".join(parts)
return ""
def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]:
"""Serialize assistant ``say`` tool calls into a ``<say>...</say>`` marker.
PaliGemma's flat text prompt has no notion of structured tool calls,
and ``_format_messages`` only reads ``role`` / ``content`` so
without this a ``say`` tool call is dropped entirely and never
supervised. Rewriting it into the content text as a ``<say>...</say>``
marker lets the LM head learn to emit it; the runtime parses it back
via ``_split_plan_and_say``. Messages without ``say`` tool calls are
returned unchanged (the structured calls, if any, are still dropped).
"""
tool_calls = message.get("tool_calls")
if not tool_calls:
return message
say_texts: list[str] = []
for call in tool_calls:
if not isinstance(call, dict):
continue
fn = call.get("function") or {}
if fn.get("name") != "say":
continue
args = fn.get("arguments")
if isinstance(args, str):
try:
import json # noqa: PLC0415
args = json.loads(args)
except (ValueError, TypeError):
args = {}
text = args.get("text", "") if isinstance(args, dict) else ""
if text:
say_texts.append(str(text))
new = dict(message)
new.pop("tool_calls", None)
if not say_texts:
return new
base = _content_to_text(new.get("content")).strip()
marker = "".join(f"<say>{t}</say>" for t in say_texts)
new["content"] = f"{base}\n{marker}" if base else marker
return new
def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]:
"""Normalise a message's content to a plain string.
The recipe renderer can emit ``content`` as a string OR as a list
of HF-style multimodal blocks (``{type: text, text: ...}``,
``{type: image, feature: ...}``). PaliGemma's text tokenizer can
only consume strings, so we flatten: drop image blocks (cameras
flow through ``observation.images.*`` separately) and join text
block texts.
"""
new = dict(message)
new.pop("stream", None)
new.pop("target", None)
content = new.get("content")
if content is None:
new["content"] = ""
elif isinstance(content, str):
pass
elif isinstance(content, list):
parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
if block.get("type") == "text":
t = block.get("text", "")
if isinstance(t, str):
parts.append(t)
new["content"] = "\n".join(parts)
else:
new["content"] = str(content)
return new
def _is_batched_messages(messages: Any) -> bool:
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
if value is None:
return [None] * batch_size
if isinstance(value, torch.Tensor):
if value.numel() == 1:
return [int(value.item())] * batch_size
values = value.reshape(-1).tolist()
return [int(v) for v in values[:batch_size]]
if isinstance(value, (list, tuple)):
if len(value) == 1:
return _sample_indices(value[0], batch_size)
return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]]
return [int(value)] * batch_size
# ---------------------------------------------------------------------------
# VQA spatial answers → PaliGemma <loc> format (PI052 only)
#
# PaliGemma is pre-trained on detection / pointing with a ``<locNNNN>``
# vocabulary (normalized [0, 1023]). The recipe's bbox / keypoint VQA
# answers are stored as JSON in Qwen2.5-VL's grounding convention:
# **01000 normalized coordinates**, NOT pixels. (Verified empirically
# on the published datasets: x and y both span 0..1000 with ~30% of
# values exceeding the camera's pixel dimensions — they're not pixels.)
# Converting to ``<loc>`` is therefore camera-resolution-independent:
# ``loc_idx = round(coord / 1000 * 1023)``. We do the conversion here —
# not in the dataset — so the dataset keeps the raw JSON and stays
# backbone-agnostic.
# ---------------------------------------------------------------------------
# The 01000 scale Qwen2.5-VL emits for grounding coordinates.
_VQA_COORD_SCALE = 1000.0
def register_paligemma_loc_tokens(tokenizer: Any) -> Any:
"""Make PaliGemma's ``<locDDDD>`` ids match on raw text — single tokens.
PaliGemma reserves vocab ids [256000, 257023] for ``<locDDDD>``
(detection / pointing) tokens, but the *stock* tokenizer does NOT
match them when encoding raw text it BPE-splits ``<loc0162>`` into
7 pieces (``<``, ``loc``, ``0``, ``1``, ``6``, ``2``, ``>``). Training
the LM head on a ``<loc>`` target then supervises those 7 generic
BPE pieces instead of one detection-vocab id, the LM head learns to
emit the *character sequence*, and those pieces' logits dominate
other turns (the ``<loc>``-salad on subtasks). Registering the loc
tokens once makes them tokenize as their single ids (256000+idx),
leveraging PaliGemma's detection prior properly. Idempotent.
"""
if "<loc0000>" in getattr(tokenizer, "added_tokens_encoder", {}):
return tokenizer
tokenizer.add_tokens([f"<loc{i:04d}>" for i in range(1024)])
return tokenizer
def _loc_token(coord: float, scale: float = _VQA_COORD_SCALE) -> str:
"""PaliGemma ``<locNNNN>`` for a coord on a ``[0, scale]`` axis."""
idx = round(float(coord) / scale * 1023) if scale > 0 else 0
return f"<loc{max(0, min(1023, idx)):04d}>"
def _vqa_answer_to_loc(answer: dict[str, Any]) -> str | None:
"""Convert a bbox / keypoint VQA answer dict to PaliGemma ``<loc>`` text.
Input coordinates are in Qwen2.5-VL's 01000 normalized space (see
module-level note). y is emitted before x for each coordinate pair
(PaliGemma convention), with the integer indices in [0, 1023].
**Format: label first, locs after.** PaliGemma's pretraining puts
locs first (``<loc><loc> label``), but for our small-dataset VQA
blend that turns the LM head into a loc-emission attractor at every
``Assistant:`` position VQA targets share their first supervised
token with ~25% of all text samples, and the head collapses to
emitting ``<loc>`` regardless of the prompt. Putting the label
first (``label <locY><locX>``) means every text sample (subtask,
memory, VQA, ) starts the supervised target with a real word,
breaking the attractor. The model still learns the loc vocabulary
for the *spatial* portion of the answer; it just can't fire it as
the first generation step from a clean prompt.
Returns ``None`` for non-spatial answers (count / attribute /
spatial-relation) those keep their JSON form.
"""
point = answer.get("point")
if isinstance(point, list | tuple) and len(point) == 2 and "point_format" in answer:
try:
x, y = float(point[0]), float(point[1])
except (TypeError, ValueError):
return None
label = str(answer.get("label", "")).strip()
if not label:
return None
return f"{label} {_loc_token(y)}{_loc_token(x)}"
detections = answer.get("detections")
if isinstance(detections, list) and detections:
parts: list[str] = []
for det in detections:
if not isinstance(det, dict):
continue
box = det.get("bbox")
if not (isinstance(box, list | tuple) and len(box) == 4):
continue
try:
x1, y1, x2, y2 = (float(v) for v in box)
except (TypeError, ValueError):
continue
label = str(det.get("label", "")).strip()
if not label:
continue
toks = (
f"{_loc_token(y1)}{_loc_token(x1)}"
f"{_loc_token(y2)}{_loc_token(x2)}"
)
parts.append(f"{label} {toks}")
return " ; ".join(parts) if parts else None
return None
def _messages_vqa_to_loc(
messages: list[dict[str, Any]],
target_indices: list[int],
) -> list[dict[str, Any]]:
"""Rewrite bbox / keypoint VQA *target* answers from JSON to ``<loc>`` text.
Each target turn whose content parses as a spatial VQA answer is
converted. Non-spatial answers and subtask / memory targets (plain
text not JSON) are left untouched. Camera-independent: VQA coords
are 01000 normalized, so no observation lookup is needed.
"""
if not target_indices:
return messages
out = list(messages)
for idx in target_indices:
if not (0 <= idx < len(out)):
continue
content = out[idx].get("content")
if not isinstance(content, str) or not content.strip():
continue
try:
answer = json.loads(content)
except (ValueError, TypeError):
continue # subtask / memory targets are plain text — skip
if not isinstance(answer, dict):
continue
loc_text = _vqa_answer_to_loc(answer)
if loc_text is not None:
out[idx] = {**out[idx], "content": loc_text}
return out
def _format_messages(
messages: list[dict[str, Any]],
target_indices: list[int] | None = None,
eos_token: str | None = None,
) -> tuple[str, list[tuple[int, int]]]:
"""Concatenate messages into the π0.5-style flat prompt.
When both ``target_indices`` and ``eos_token`` are given, the EOS
string is appended to each supervised target turn's content and the
returned span covers it so the label builder marks the EOS token
as a supervised label. That teaches the LM head where the answer
*ends*: without an EOS in the target span the model is never given a
stop signal and rambles to ``max_length`` at inference. Inference
callers omit both args (no EOS baked into the prompt the model
generates it and ``select_message`` stops on it).
Returns:
prompt: the full text the tokenizer will consume.
msg_spans: list of ``(char_start, char_end)`` covering each
message's supervised payload (content, plus the
appended EOS for target turns) within ``prompt``.
"""
targets = set(target_indices or [])
parts: list[str] = []
spans: list[tuple[int, int]] = []
cursor = 0
for i, m in enumerate(messages):
role = m.get("role", "user")
content = m.get("content", "") or ""
# Role tag + newline. The model has to learn to emit the same
# role tokens at generation time, which is fine for greedy
# decoding because the chat template is implicit in the
# supervised target span.
header = f"{role.capitalize()}: "
# A supervised target turn ends with EOS so the model learns to
# terminate; the span below covers content + EOS. Non-target
# turns (and inference) carry no EOS.
body = content + eos_token if (eos_token and i in targets) else content
# span covers the content (+ EOS) portion only — never the role
# tag — so labels are computed over the supervised payload.
full = header + body + "\n"
start = cursor + len(header)
end = start + len(body)
parts.append(full)
spans.append((start, end))
cursor += len(full)
return "".join(parts), spans
@dataclass
@ProcessorStepRegistry.register(name="pi052_text_tokenizer")
class PI052TextTokenizerStep(ProcessorStep):
"""Render messages → token ids + label mask + predict_actions flag.
No chat template; concatenates messages as
``User: ... \\nAssistant: ...`` text.
"""
tokenizer_name: str = "google/paligemma-3b-pt-224"
max_length: int = 200
padding: str = "max_length"
padding_side: str = "right"
plan_dropout_prob: float = 0.0
memory_dropout_prob: float = 0.0
subtask_dropout_prob: float = 0.0
interjection_dropout_prob: float = 0.0
dropout_seed: int | None = None
def __post_init__(self) -> None:
self._tokenizer: Any = None
def _ensure_tokenizer(self) -> Any:
if self._tokenizer is not None:
return self._tokenizer
from transformers import AutoTokenizer # noqa: PLC0415
self._tokenizer = register_paligemma_loc_tokens(
AutoTokenizer.from_pretrained(self.tokenizer_name)
)
return self._tokenizer
# ------------------------------------------------------------------
# Pipeline step
# ------------------------------------------------------------------
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
transition = transition.copy()
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
messages = complementary.get("messages") or []
if not messages:
# No recipe was rendered — caller will fall back to the
# plain Pi0.5 prompt path. We pass the transition through
# unmodified.
return transition
tokenizer = self._ensure_tokenizer()
# Normalized proprioceptive state (set by NormalizerProcessorStep, which
# runs before this step). Injected into low-level action prompts so the
# action expert sees proprioception, matching pi05's discretized State:.
state_all = (transition.get(TransitionKey.OBSERVATION) or {}).get(OBS_STATE)
# VQA coords are 01000 normalized (Qwen2.5-VL convention) — the
# <loc> conversion is camera-resolution-independent and needs no
# observation lookup here.
if _is_batched_messages(messages):
indices_iter = _sample_indices(complementary.get("index"), len(messages))
encoded = [
self._encode_messages(
tokenizer,
msg,
list(streams),
list(tgt_indices),
complementary,
sample_idx=int(s_idx) if s_idx is not None else None,
state_row=_state_row_at(state_all, pos),
)
for pos, (msg, streams, tgt_indices, s_idx) in enumerate(
zip(
messages,
complementary.get("message_streams") or [[] for _ in messages],
complementary.get("target_message_indices") or [[] for _ in messages],
indices_iter,
strict=False,
)
)
]
else:
sample_idx = _sample_indices(complementary.get("index"), 1)[0]
encoded = [
self._encode_messages(
tokenizer,
messages,
list(complementary.get("message_streams") or []),
list(complementary.get("target_message_indices") or []),
complementary,
sample_idx=sample_idx,
state_row=_state_row_at(state_all, 0),
)
]
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
obs[OBS_LANGUAGE_TOKENS] = torch.stack([ids for ids, _, _, _, _ in encoded])
obs[OBS_LANGUAGE_ATTENTION_MASK] = torch.stack([attn for _, attn, _, _, _ in encoded])
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = {
**complementary,
"text_labels": torch.stack([labels for _, _, labels, _, _ in encoded]),
"predict_actions": torch.stack([pred for _, _, _, pred, _ in encoded]),
}
return transition
def _encode_messages(
self,
tokenizer: Any,
messages: list[dict[str, Any]],
message_streams: list[str | None],
target_indices: list[int],
complementary: dict[str, Any],
sample_idx: int | None = None,
state_row: Any = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor, str]:
# Optional: drop non-target messages per the dropout config.
# Keeps the supervised-target indices stable by re-mapping
# after removal.
if (
self.plan_dropout_prob
or self.memory_dropout_prob
or self.subtask_dropout_prob
or self.interjection_dropout_prob
):
messages, target_indices = self._apply_prompt_dropout(
messages,
target_indices,
complementary,
sample_idx=sample_idx,
)
# Rewrite bbox / keypoint VQA target answers from JSON to
# PaliGemma <loc> text. Coords are 01000 normalized so this is
# camera-independent.
messages = _messages_vqa_to_loc(messages, target_indices)
# Flatten ``say`` tool calls into ``<say>...</say>`` text before
# stripping, so the spoken reply is actually tokenized and
# supervised (PaliGemma's flat prompt has no structured calls).
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages]
# Low-level (action-conditioning) samples get the discretized state
# appended to their user message, mirroring pi05's
# "..., State: {256-bin};" so the action expert sees proprioception.
# Higher-level text streams (subtask/memory generation) stay state-free.
if state_row is not None and any(s == "low_level" for s in message_streams):
state_str = discretize_state_str(state_row)
for m in reversed(messages):
if m.get("role") == "user":
base = _content_to_text(m.get("content", ""))
m["content"] = f"{base}, State: {state_str};"
break
# Append EOS to supervised target turns so the LM head learns to
# stop (the span covers it → it becomes a supervised label).
prompt, spans = _format_messages(
messages, target_indices, getattr(tokenizer, "eos_token", None)
)
encoded = tokenizer(
prompt,
max_length=self.max_length,
padding=self.padding,
truncation=True,
return_tensors="pt",
return_offsets_mapping=True,
padding_side=self.padding_side,
)
input_ids = encoded["input_ids"][0]
attention_mask = encoded["attention_mask"][0].bool()
offsets = encoded["offset_mapping"][0] # (seq, 2), char (start,end)
# Build label mask: -100 everywhere except over supervised
# target message char ranges.
labels = torch.full_like(input_ids, fill_value=-100)
for idx in target_indices:
if idx >= len(spans):
continue
char_start, char_end = spans[idx]
for token_pos in range(input_ids.shape[0]):
if not attention_mask[token_pos]:
continue
tok_start, tok_end = int(offsets[token_pos, 0]), int(offsets[token_pos, 1])
if tok_end <= char_start or tok_start >= char_end:
continue
labels[token_pos] = input_ids[token_pos]
# Scan ALL message streams (not just targets): the
# ``low_level_execution`` recipe drops ``target: true`` on
# the assistant to avoid trivial copy-from-user text-CE; the
# flow loss still needs to fire, gated by ``stream: low_level``.
predict_actions = torch.tensor(
bool(any(s == "low_level" for s in message_streams)),
dtype=torch.bool,
)
return input_ids, attention_mask, labels, predict_actions, prompt
# ------------------------------------------------------------------
# Per-component prompt dropout (Pi0.7 §V.E)
# ------------------------------------------------------------------
def _apply_prompt_dropout(
self,
messages: list[dict[str, Any]],
target_indices: list[int],
complementary: dict[str, Any],
sample_idx: int | None = None,
) -> tuple[list[dict[str, Any]], list[int]]:
"""Drop messages classified as plan/memory/subtask context.
Targets are *never* dropped (they're the supervised payload).
Re-maps target_indices to the new positions after drops.
"""
import random # noqa: PLC0415
seed = self.dropout_seed
if seed is None:
# Canonical row-index key set by ``BatchProcessor`` /
# ``render_messages_processor``. Falling back to other
# keys silently gave every sample seed=0 → identical
# dropout pattern across the whole epoch.
seed_src = sample_idx if sample_idx is not None else complementary.get("index", 0)
try:
if hasattr(seed_src, "item"):
seed_src = seed_src.item()
seed = int(seed_src)
except (TypeError, ValueError):
seed = 0
rng = random.Random(seed)
keep_indices: list[int] = []
for idx, msg in enumerate(messages):
if idx in target_indices:
keep_indices.append(idx)
continue
kind = _classify_for_dropout(msg)
prob = {
"plan": self.plan_dropout_prob,
"memory": self.memory_dropout_prob,
"subtask": self.subtask_dropout_prob,
"interjection": self.interjection_dropout_prob,
}.get(kind, 0.0)
if prob > 0.0 and rng.random() < prob:
continue
keep_indices.append(idx)
# Build remap and apply
new_messages = [messages[i] for i in keep_indices]
old_to_new = {old: new for new, old in enumerate(keep_indices)}
new_targets = [old_to_new[t] for t in target_indices if t in old_to_new]
return new_messages, new_targets
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def _classify_for_dropout(message: dict[str, Any]) -> str | None:
"""Heuristic content-prefix classifier (plan / memory / subtask)."""
content = message.get("content")
if isinstance(content, list):
text_parts = [b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"]
content = " ".join(text_parts)
elif content is None:
return None
elif not isinstance(content, str):
return None
s = content.strip()
if s.startswith("Plan:") or s.startswith("Previous plan"):
return "plan"
if s.startswith("Memory:") or s.startswith("Previous memory"):
return "memory"
if s.startswith("Current subtask") or s.startswith("Completed subtask"):
return "subtask"
return None
+387 -2
View File
@@ -14,18 +14,28 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import copy
from typing import TYPE_CHECKING, Literal
import torch
from torch import nn
from torch import Tensor, nn
from torch.nn import functional as F # noqa: N812
from lerobot.utils.import_utils import _transformers_available
# Default PaliGemma SigLIP input resolution. Mirrors
# ``pi05.configuration_pi05.DEFAULT_IMAGE_SIZE``; duplicated as a plain constant
# to avoid importing the pi05 package here (which would create an import cycle:
# pi_gemma -> pi05.__init__ -> modeling_pi05 -> pi_gemma).
DEFAULT_IMAGE_SIZE = 224
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import (
GemmaAttention,
GemmaConfig,
@@ -49,6 +59,8 @@ else:
GradientCheckpointingLayer = None
BaseModelOutputWithPast = None
create_causal_mask = None
CONFIG_MAPPING = None
modeling_gemma = None
def _gated_residual(
@@ -275,6 +287,8 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc]
# Convert to bfloat16 if the first layer uses bfloat16
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.bfloat16)
if causal_mask is not None and torch.is_floating_point(causal_mask):
causal_mask = causal_mask.to(dtype=hidden_states.dtype)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
@@ -367,3 +381,374 @@ __all__ = [
"PaliGemmaModelWithPiGemma",
"PaliGemmaForConditionalGenerationWithPiGemma",
]
# PI0.5 / PI052 dual-expert backbone: generic PaliGemma + Gemma action-expert
# transformer machinery used by the pi052 policy. GemmaVariantConfig is openpi's
# width/depth variant config (renamed from GemmaConfig to avoid clashing with
# transformers' GemmaConfig).
def sdpa_attention_forward(
module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0,
):
"""Drop-in for ``modeling_gemma.eager_attention_forward`` using
``torch.nn.functional.scaled_dot_product_attention``.
PyTorch SDPA picks the memory-efficient kernel for arbitrary additive
bias masks (the FA backend only accepts causal/sliding-window). On
H100 that is ~1.3-1.7x faster and uses ~30-40% less attention memory
than the eager softmax(QK^T)+matmul path. Mirrors eager's signature
and output shape (``(B, Lq, H, D)``) so call sites are unchanged.
"""
n_rep = module.num_key_value_groups
if n_rep > 1:
key = key.repeat_interleave(n_rep, dim=1)
value = value.repeat_interleave(n_rep, dim=1)
if attention_mask is not None and attention_mask.dtype != query.dtype:
attention_mask = attention_mask.to(dtype=query.dtype)
attn_output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=dropout if module.training else 0.0,
is_causal=False,
scale=scaling,
)
return attn_output.transpose(1, 2).contiguous(), None
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# Concatenate and process attention
query_states = torch.cat(query_states, dim=2)
key_states = torch.cat(key_states, dim=2)
value_states = torch.cat(value_states, dim=2)
dummy_tensor = torch.zeros(
query_states.shape[0],
query_states.shape[2],
query_states.shape[-1],
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
att_output, _ = sdpa_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
attention_mask,
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
after_first_residual = out_emb.clone()
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = _gated_residual(after_first_residual, out_emb, gate)
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
class GemmaVariantConfig: # see openpi `gemma.py: Config`
"""Configuration for Gemma model variants."""
def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim):
self.width = width
self.depth = depth
self.mlp_dim = mlp_dim
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
def get_gemma_config(variant: str) -> GemmaVariantConfig: # see openpi `gemma.py: get_config`
"""Returns config for specified gemma variant."""
if variant == "gemma_300m":
return GemmaVariantConfig(
width=1024,
depth=18,
mlp_dim=4096,
num_heads=8,
num_kv_heads=1,
head_dim=256,
)
elif variant == "gemma_2b":
return GemmaVariantConfig(
width=2048,
depth=18,
mlp_dim=16_384,
num_heads=8,
num_kv_heads=1,
head_dim=256,
)
else:
raise ValueError(f"Unknown variant: {variant}")
class PaliGemmaWithExpertModel(
nn.Module
): # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel` this class is almost a exact copy of PaliGemmaWithExpertModel in openpi
"""PaliGemma model with action expert for PI05."""
def __init__(
self,
vlm_config,
action_expert_config,
use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16",
image_size: int = DEFAULT_IMAGE_SIZE,
freeze_vision_encoder: bool = False,
train_expert_only: bool = False,
):
if use_adarms is None:
use_adarms = [False, False]
super().__init__()
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
vlm_config_hf.image_token_index = 257152
vlm_config_hf.text_config.hidden_size = vlm_config.width
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
vlm_config_hf.text_config.head_dim = vlm_config.head_dim
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.image_size = image_size
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.dtype = "float32"
action_expert_config_hf = CONFIG_MAPPING["gemma"](
head_dim=action_expert_config.head_dim,
hidden_size=action_expert_config.width,
intermediate_size=action_expert_config.mlp_dim,
num_attention_heads=action_expert_config.num_heads,
num_hidden_layers=action_expert_config.depth,
num_key_value_heads=action_expert_config.num_kv_heads,
vocab_size=257152,
hidden_activation="gelu_pytorch_tanh",
dtype="float32",
use_adarms=use_adarms[1],
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision)
self._set_requires_grad()
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
if precision == "bfloat16":
self.to(dtype=torch.bfloat16)
elif precision == "float32":
self.to(dtype=torch.float32)
return
else:
raise ValueError(f"Invalid precision: {precision}")
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Saves memory vs full float32; more memory than only 3 params.
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"lm_head",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
]
for name, param in self.named_parameters():
if any(selector in name for selector in params_to_keep_float32):
param.data = param.data.to(dtype=torch.float32)
def _set_requires_grad(self):
if self.freeze_vision_encoder:
self.paligemma.model.vision_tower.eval()
for param in self.paligemma.model.vision_tower.parameters():
param.requires_grad = False
if self.train_expert_only:
self.paligemma.eval()
for param in self.paligemma.parameters():
param.requires_grad = False
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.paligemma.model.vision_tower.eval()
if self.train_expert_only:
self.paligemma.eval()
def embed_image(self, image: torch.Tensor):
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32).
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
# OpenPI / big_vision convention: image (soft) tokens are NOT scaled by the
# Gemma embedder normalizer (sqrt(hidden_size)) — only text tokens are. lerobot/pi05_base
# was trained in this regime, so scaling image features here over-scales them ~45x and
# breaks the pretrained vision-language alignment. Keep image features un-normalized.
features = image_outputs.pooler_output
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
def forward(
self,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
inputs_embeds: list[torch.FloatTensor] | None = None,
use_cache: bool | None = None,
adarms_cond: list[torch.Tensor] | None = None,
):
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.model.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
)
prefix_past_key_values = prefix_output.past_key_values
prefix_output = prefix_output.last_hidden_state
suffix_output = None
elif inputs_embeds[0] is None:
suffix_output = self.gemma_expert.model.forward(
inputs_embeds=inputs_embeds[1],
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
)
suffix_output = suffix_output.last_hidden_state
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
hasattr(self.gemma_expert.model, "gradient_checkpointing")
and self.gemma_expert.model.gradient_checkpointing
and self.training
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
)
# final norm
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
# Apply gradient checkpointing to final norm if enabled
if use_gradient_checkpointing:
outputs_embeds = torch.utils.checkpoint.checkpoint(
compute_final_norms,
inputs_embeds,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
)
else:
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
prefix_output = outputs_embeds[0]
suffix_output = outputs_embeds[1]
prefix_past_key_values = None
return [prefix_output, suffix_output], prefix_past_key_values
+5 -85
View File
@@ -29,7 +29,6 @@ from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
from torch import Tensor, nn
from lerobot.__version__ import __version__
from lerobot.configs import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.utils.hub import HubMixin
@@ -39,67 +38,6 @@ from .utils import log_model_loading_keys
T = TypeVar("T", bound="PreTrainedPolicy")
def _build_card_context(
cfg: TrainPipelineConfig | None,
dataset_repo_id: str | None,
input_features: dict | None,
output_features: dict | None,
) -> dict:
"""Collect optional data for the model-card template.
Returns plain values only (no Markdown) the template in
``lerobot/templates/lerobot_modelcard_template.md`` decides how and whether to show
each one. Everything is best-effort: anything unavailable is left empty/None and the
template simply skips that section, so this never breaks a Hub push.
"""
context = {
"training": None,
"input_features": input_features or {},
"output_features": output_features or {},
"dataset": None,
"robot_type": None,
"cameras": [],
}
if cfg is not None:
optimizer = getattr(cfg, "optimizer", None)
context["training"] = {
"steps": cfg.steps,
"batch_size": cfg.batch_size,
"seed": cfg.seed,
"optimizer": getattr(optimizer, "type", None) if optimizer else None,
"lr": getattr(optimizer, "lr", None) if optimizer else None,
"lerobot_version": __version__,
}
if dataset_repo_id:
dataset_cfg = getattr(cfg, "dataset", None)
try:
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
meta = LeRobotDatasetMetadata(
dataset_repo_id,
root=getattr(dataset_cfg, "root", None),
revision=getattr(dataset_cfg, "revision", None),
)
context["dataset"] = {
"repo_id": dataset_repo_id,
"episodes": meta.total_episodes,
"frames": meta.total_frames,
"fps": meta.fps,
"tasks": [str(task) for task in meta.tasks.index],
}
context["robot_type"] = meta.robot_type
context["cameras"] = [key.split(".")[-1] for key in meta.camera_keys]
except Exception as e: # noqa: BLE001 — dataset details are optional, never fail the push
logging.warning(
f"Could not load dataset metadata for '{dataset_repo_id}'; those sections will be "
f"omitted from the model card. ({e})"
)
return context
class ActionSelectKwargs(TypedDict, total=False):
noise: Tensor | None
@@ -290,7 +228,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
card = self.generate_model_card(
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
)
card.save(str(saved_path / "README.md"))
@@ -308,20 +246,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
logging.info(f"Model pushed to {commit_info.repo_url.url}")
def generate_model_card(
self,
dataset_repo_id: str,
model_type: str,
license: str | None,
tags: list[str] | None,
cfg: TrainPipelineConfig | None = None,
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
) -> ModelCard:
base_model_mapping = {
"smolvla": "lerobot/smolvla_base",
"pi0": "lerobot/pi0_base",
"pi05": "lerobot/pi05_base",
"pi0_fast": "lerobot/pi0fast-base",
"xvla": "lerobot/xvla-base",
}
base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model
card_data = ModelCardData(
license=license or "apache-2.0",
@@ -330,20 +257,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
model_name=model_type,
datasets=dataset_repo_id,
base_model=base_model_mapping.get(model_type),
base_model=base_model,
)
context = _build_card_context(
cfg, dataset_repo_id, self.config.input_features, self.config.output_features
)
# Used by the template to pre-fill commands and the "Fine-tuned from" line.
context["policy_repo_id"] = getattr(self.config, "repo_id", None)
context["base_model"] = base_model_mapping.get(model_type)
template_card = (
files("lerobot.templates").joinpath("lerobot_modelcard_template.md").read_text(encoding="utf-8")
)
card = ModelCard.from_template(card_data, template_str=template_card, **context)
card = ModelCard.from_template(card_data, template_str=template_card)
card.validate()
return card
-3
View File
@@ -175,9 +175,6 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
complementary_data.pop("language_persistent", None)
complementary_data.pop("language_events", None)
if "messages" in complementary_data:
messages = complementary_data["messages"]
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
+54 -278
View File
@@ -32,6 +32,7 @@ from __future__ import annotations
import importlib
import json
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
@@ -280,11 +281,6 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
_serialized_state_filenames: tuple[str | None, ...] | None = field(
default=None,
init=False,
repr=False,
)
def __call__(self, data: TInput) -> TOutput:
"""Processes input data through the full pipeline.
@@ -342,108 +338,30 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
transition = processor_step(transition)
yield transition
def _get_sanitized_name(self) -> str:
"""Return a filename-safe version of the pipeline name.
def _save_pretrained(self, save_directory: Path, **kwargs):
"""Internal method to comply with `HubMixin`'s saving mechanism.
Returns:
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
This method does the actual saving work and is called by HubMixin.save_pretrained.
"""
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
config_filename = kwargs.pop("config_filename", None)
@staticmethod
def _get_state_filename(
*,
step_index: int,
registry_name: str | None,
sanitized_name: str,
) -> str:
"""Return the safetensors filename for one stateful processor step.
# Sanitize the pipeline name to create a valid filename prefix.
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
Args:
step_index: The index of the processor step in this pipeline.
registry_name: The registered processor step name, if available.
sanitized_name: The filename-safe pipeline name.
if config_filename is None:
config_filename = f"{sanitized_name}.json"
Returns:
The state filename used by the existing disk serialization format.
"""
if registry_name:
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
return f"{sanitized_name}_step_{step_index}.safetensors"
@staticmethod
def _get_state_key(state_filename: str) -> str:
"""Return the in-memory state key for a serialized state filename.
Args:
state_filename: The `.safetensors` filename from the serialized config.
Returns:
The state key used by the in-memory pipeline state dictionary.
"""
return state_filename.removesuffix(".safetensors")
@staticmethod
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
"""Return serialized state filenames in step order.
Args:
loaded_config: A validated processor pipeline config.
Returns:
A tuple containing each step's serialized state filename, or None for stateless steps.
"""
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
"""Return expected state filenames in step order for `load_state_dict()`.
Returns:
The preserved serialized state filenames when available, otherwise filenames derived from
current non-empty step state.
"""
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
self.steps
):
return self._serialized_state_filenames
sanitized_name = self._get_sanitized_name()
state_filenames: list[str | None] = []
for step_index, processor_step in enumerate(self.steps):
step_state_dict = processor_step.state_dict()
if not step_state_dict:
state_filenames.append(None)
continue
registry_name = getattr(processor_step.__class__, "_registry_name", None)
state_filenames.append(
self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
)
return tuple(state_filenames)
def get_config(self) -> dict[str, Any]:
"""Return the JSON-serializable pipeline configuration.
Returns:
A dictionary with the same content that `save_pretrained()` writes as JSON.
"""
sanitized_name = self._get_sanitized_name()
pipeline_config: dict[str, Any] = {
config: dict[str, Any] = {
"name": self.name,
"steps": [],
}
# Iterate through each step to build its configuration entry.
for step_index, processor_step in enumerate(self.steps):
registry_name = getattr(processor_step.__class__, "_registry_name", None)
step_entry: dict[str, Any] = {}
step_entry: dict[str, Any] = {}
# Prefer registry name for portability, otherwise fall back to full class path.
if registry_name:
step_entry["registry_name"] = registry_name
else:
@@ -451,110 +369,31 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
)
step_entry["config"] = processor_step.get_config()
# Save step configuration if `get_config` is implemented.
if hasattr(processor_step, "get_config"):
step_entry["config"] = processor_step.get_config()
step_state_dict = processor_step.state_dict()
if step_state_dict:
step_entry["state_file"] = self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
# Save step state if `state_dict` is implemented and returns a non-empty dict.
if hasattr(processor_step, "state_dict"):
state = processor_step.state_dict()
if state:
# Clone tensors to avoid modifying the original state.
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
pipeline_config["steps"].append(step_entry)
# Create a unique filename for the state file.
if registry_name:
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
else:
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
return pipeline_config
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
step_entry["state_file"] = state_filename
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
"""Return pipeline state tensors grouped by state key.
config["steps"].append(step_entry)
Returns:
A dictionary mapping suffixless state keys to cloned step state dictionaries.
"""
sanitized_name = self._get_sanitized_name()
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for step_index, processor_step in enumerate(self.steps):
step_state_dict = processor_step.state_dict()
if not step_state_dict:
continue
registry_name = getattr(processor_step.__class__, "_registry_name", None)
state_filename = self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
state_key = self._get_state_key(state_filename)
pipeline_state_dict[state_key] = {
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
}
return pipeline_state_dict
def load_state_dict(
self,
state_dict: dict[str, dict[str, torch.Tensor]],
) -> None:
"""Load pipeline state tensors into the existing steps.
Args:
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
Raises:
KeyError: If loading finds missing expected state or unexpected extra state.
"""
expected_state_filenames = self._get_state_filenames_for_loading()
used_state_keys: set[str] = set()
for step_index, (processor_step, state_filename) in enumerate(
zip(self.steps, expected_state_filenames, strict=True)
):
if state_filename is None:
continue
state_key = self._get_state_key(state_filename)
if state_key not in state_dict:
raise KeyError(
f"Missing state key '{state_key}' for processor step {step_index}. "
f"Available state keys: {sorted(state_dict.keys())}"
)
processor_step.load_state_dict(state_dict[state_key])
used_state_keys.add(state_key)
unexpected_state_keys = set(state_dict) - used_state_keys
if unexpected_state_keys:
expected_state_key_set = {
self._get_state_key(state_filename)
for state_filename in expected_state_filenames
if state_filename is not None
}
raise KeyError(
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
f"Expected state keys: {sorted(expected_state_key_set)}"
)
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
"""Internal method to comply with `HubMixin`'s saving mechanism.
This method does the actual saving work and is called by HubMixin.save_pretrained.
"""
config_filename = kwargs.pop("config_filename", None)
sanitized_name = self._get_sanitized_name()
if config_filename is None:
config_filename = f"{sanitized_name}.json"
pipeline_config = self.get_config()
pipeline_state_dict = self.state_dict()
for state_key, step_state_dict in pipeline_state_dict.items():
state_filename = f"{state_key}.safetensors"
save_file(step_state_dict, save_directory / state_filename)
with open(save_directory / config_filename, "w") as file_pointer:
json.dump(pipeline_config, file_pointer, indent=2)
# Write the main configuration JSON file.
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
json.dump(config, file_pointer, indent=2)
def save_pretrained(
self,
@@ -738,54 +577,12 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
cls._validate_overrides_used(validated_overrides, loaded_config)
# 5. Construct and return the final pipeline instance
pipeline = cls(
return cls(
steps=steps,
name=loaded_config.get("name", "DataProcessorPipeline"),
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
)
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
return pipeline
@classmethod
def from_config(
cls,
config: dict[str, Any],
*,
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
overrides: dict[str, Any] | None = None,
to_transition: Callable[[TInput], EnvTransition] | None = None,
to_output: Callable[[EnvTransition], TOutput] | None = None,
) -> DataProcessorPipeline[TInput, TOutput]:
"""Build a pipeline from an in-memory config and optional state tensors.
Args:
config: A config dictionary with the same structure as the saved processor JSON.
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
overrides: Optional constructor overrides keyed by registry name or class name.
to_transition: Optional converter from input data to `EnvTransition`.
to_output: Optional converter from `EnvTransition` to output data.
Returns:
A processor pipeline built from the config and optional state.
"""
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
cls._validate_overrides_used(remaining_override_keys, config)
pipeline = cls(
steps=steps,
name=config.get("name", "DataProcessorPipeline"),
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
)
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
if state_dict is not None:
pipeline.load_state_dict(state_dict)
return pipeline
@classmethod
def _load_config(
@@ -869,7 +666,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
) from e
@classmethod
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
def _validate_loaded_config(
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
) -> None:
"""Validate that a config was loaded and is a valid processor config.
This method validates processor config format with intelligent migration detection:
@@ -889,7 +688,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
Args:
model_id: The model identifier (used for migration detection)
loaded_config: The loaded config value to validate (may be non-dict)
loaded_config: The loaded config dictionary (guaranteed non-None)
config_filename: The config filename that was loaded (for error messages)
Raises:
@@ -903,14 +702,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
model_id,
f"Config file '{config_filename}' is not a valid processor configuration",
)
loaded_config_description = (
list(loaded_config.keys())
if isinstance(loaded_config, dict)
else type(loaded_config).__name__
)
raise ValueError(
f"Config file '{config_filename}' is not a valid processor configuration. "
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
)
@classmethod
@@ -972,41 +766,26 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
ImportError: If a step class cannot be imported or found in registry
ValueError: If a step cannot be instantiated with its configuration
"""
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
return steps, remaining_override_keys
@classmethod
def _build_steps_from_config(
cls,
loaded_config: dict[str, Any],
overrides: dict[str, Any],
) -> tuple[list[ProcessorStep], set[str]]:
"""Build processor steps from config without loading tensor state.
Args:
loaded_config: The loaded processor configuration.
overrides: User-provided constructor overrides keyed by step key.
Returns:
A tuple containing instantiated steps and override keys that did not match a step.
"""
processor_steps: list[ProcessorStep] = []
remaining_override_keys = set(overrides.keys())
steps: list[ProcessorStep] = []
override_keys = set(overrides.keys())
for step_entry in loaded_config["steps"]:
# 1. Get step class and key
step_class, step_key = cls._resolve_step_class(step_entry)
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
if step_key in remaining_override_keys:
remaining_override_keys.discard(step_key)
# 2. Instantiate step with overrides
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
processor_steps.append(processor_step)
# 3. Load step state if available
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
return processor_steps, remaining_override_keys
# 4. Track used overrides
if step_key in override_keys:
override_keys.discard(step_key)
steps.append(step_instance)
return steps, override_keys
@classmethod
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
@@ -1317,7 +1096,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
return True
@classmethod
def _is_processor_config(cls, config: Any) -> bool:
def _is_processor_config(cls, config: dict) -> bool:
"""Check if config follows DataProcessorPipeline format.
This method validates the processor configuration structure:
@@ -1368,9 +1147,6 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
Returns:
True if config follows valid DataProcessorPipeline format, False otherwise
"""
if not isinstance(config, dict):
return False
# Must have a "steps" field with a list of step configurations
if not isinstance(config.get("steps"), list):
return False
@@ -50,7 +50,17 @@ class RenderMessagesStep(ProcessorStep):
events = complementary_data.get(LANGUAGE_EVENTS) or []
if not persistent and not events:
return transition
rendered = _fallback_low_level_render(complementary_data.get("task"))
if rendered is None:
return transition
new_transition = transition.copy()
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
new_complementary_data.update(rendered)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
if _is_batched_language(persistent) or _is_batched_language(events):
return self._call_batch(transition, complementary_data, persistent, events)
timestamp = complementary_data.get("timestamp")
if timestamp is None:
@@ -67,18 +77,147 @@ class RenderMessagesStep(ProcessorStep):
dataset_ctx=self.dataset_ctx,
)
if rendered is None:
return None
rendered = _fallback_low_level_render(complementary_data.get("task"))
if rendered is None:
return None
new_transition = transition.copy()
new_complementary_data = dict(complementary_data)
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data.update(rendered)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def _call_batch(
self,
transition: EnvTransition,
complementary_data: dict[str, Any],
persistent_batch: list,
events_batch: list,
) -> EnvTransition | None:
timestamp = complementary_data.get("timestamp")
if timestamp is None:
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
batch_size = max(len(persistent_batch), len(events_batch))
messages: list[list[dict[str, Any]]] = []
message_streams: list[list[str | None]] = []
target_message_indices: list[list[int]] = []
keep_indices: list[int] = []
for i in range(batch_size):
rendered = render_sample(
recipe=self.recipe,
persistent=persistent_batch[i] if i < len(persistent_batch) else [],
events=events_batch[i] if i < len(events_batch) else [],
t=_batch_value(timestamp, i),
sample_idx=int(_batch_value(complementary_data.get("index", 0), i)),
task=_batch_value(complementary_data.get("task"), i),
dataset_ctx=self.dataset_ctx,
)
if rendered is None:
rendered = _fallback_low_level_render(_batch_value(complementary_data.get("task"), i))
if rendered is None:
continue
keep_indices.append(i)
messages.append(rendered["messages"])
message_streams.append(rendered["message_streams"])
target_message_indices.append(rendered["target_message_indices"])
if not messages:
return None
new_transition = (
_select_batch_indices(transition, keep_indices)
if len(keep_indices) != batch_size
else transition.copy()
)
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data["messages"] = messages
new_complementary_data["message_streams"] = message_streams
new_complementary_data["target_message_indices"] = target_message_indices
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Pass features through unchanged; rendering only touches complementary data."""
return features
def _scalar(value: Any) -> float | int:
"""Unwrap a tensor/array/single-element list into a Python scalar."""
if hasattr(value, "item"):
return value.item()
if isinstance(value, list):
if len(value) != 1:
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
return _scalar(value[0])
return value
def _is_batched_language(value: Any) -> bool:
return isinstance(value, list) and bool(value) and isinstance(value[0], list)
def _batch_value(value: Any, index: int) -> Any:
if value is None:
return None
if isinstance(value, list):
return value[index]
if hasattr(value, "ndim") and getattr(value, "ndim") > 0:
return _scalar(value[index])
return _scalar(value)
def _select_batch_indices(transition: EnvTransition, indices: list[int]) -> EnvTransition:
selected = transition.copy()
for key in (TransitionKey.OBSERVATION, TransitionKey.COMPLEMENTARY_DATA):
data = selected.get(key)
if isinstance(data, dict):
selected[key] = {k: _select_value(v, indices) for k, v in data.items()}
action = selected.get(TransitionKey.ACTION)
if action is not None:
selected[TransitionKey.ACTION] = _select_value(action, indices)
return selected
def _select_value(value: Any, indices: list[int]) -> Any:
if isinstance(value, list) and len(value) >= len(indices):
return [value[i] for i in indices]
if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0:
return value.index_select(0, value.new_tensor(indices).long())
return value
def _fallback_low_level_render(task: Any) -> dict[str, Any] | None:
"""Keep action-only samples trainable when no recipe branch matches."""
if hasattr(task, "item"):
task = task.item()
if isinstance(task, list):
messages = []
message_streams = []
target_message_indices = []
for t in task:
rendered = _fallback_low_level_render(t)
if rendered is None:
return None
messages.append(rendered["messages"])
message_streams.append(rendered["message_streams"])
target_message_indices.append(rendered["target_message_indices"])
return {
"messages": messages,
"message_streams": message_streams,
"target_message_indices": target_message_indices,
}
if not isinstance(task, str) or not task:
return None
return {
"messages": [{"role": "user", "content": task}],
"message_streams": ["low_level"],
"target_message_indices": [],
}
+31 -16
View File
@@ -32,6 +32,7 @@ import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, RobotObservation, TransitionKey
from lerobot.utils.constants import (
ACTION_CODE_TOKEN_MASK,
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -412,14 +413,15 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
# During inference, no action is available, skip tokenization
return new_transition
# Tokenize and get both tokens and mask
tokens, mask = self._tokenize_action(action)
# Tokenize and get masks for the full formatted sequence and the discrete action codes.
tokens, mask, code_mask = self._tokenize_action(action)
# Store mask in complementary data
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if complementary_data is None:
complementary_data = {}
complementary_data[ACTION_TOKEN_MASK] = mask
complementary_data[ACTION_CODE_TOKEN_MASK] = code_mask
complementary_data[ACTION_TOKENS] = tokens
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition
@@ -430,7 +432,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
"""
return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Tokenizes the action tensor and creates a mask.
@@ -459,6 +461,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
# The fast tokenizer expects action data and returns token IDs
tokens_list = []
masks_list = []
code_masks_list = []
for i in range(batch_size):
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
@@ -476,19 +479,26 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
if tokens.dim() > 1:
tokens = tokens.flatten()
action_code_tokens = self._act_tokens_to_paligemma_tokens(tokens)
bos_id = self._paligemma_tokenizer.bos_token_id
# add bos
prompt_tokens = torch.tensor(
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
device=action.device,
)
end_tokens = torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)
code_start = 1 + len(prompt_tokens)
code_end = code_start + len(action_code_tokens)
tokens = torch.cat(
[
torch.tensor([bos_id], device=action.device),
torch.tensor(
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
device=action.device,
),
self._act_tokens_to_paligemma_tokens(tokens),
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
prompt_tokens,
action_code_tokens,
end_tokens,
]
)
code_mask = torch.zeros(len(tokens), dtype=torch.bool, device=action.device)
code_mask[code_start:code_end] = True
# Truncate or pad to max_action_tokens
if len(tokens) > self.max_action_tokens:
@@ -497,44 +507,49 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
"Consider increasing the `max_action_tokens` in your model config if this happens frequently."
)
tokens = tokens[: self.max_action_tokens]
code_mask = code_mask[: self.max_action_tokens]
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
else:
pad_len = self.max_action_tokens - len(tokens)
mask = torch.cat(
[
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
torch.zeros(
self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device
),
torch.zeros(pad_len, dtype=torch.bool, device=action.device),
]
)
code_mask = torch.nn.functional.pad(code_mask, (0, pad_len), value=False)
# Pad tokens with zeros
tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0)
tokens = torch.nn.functional.pad(tokens, (0, pad_len), value=0)
tokens_list.append(tokens)
masks_list.append(mask)
code_masks_list.append(code_mask)
# Stack into batched tensors
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
code_masks_batch = torch.stack(code_masks_list, dim=0) # (B, max_action_tokens)
# Remove batch dimension if input was single sample
if single_sample:
tokens_batch = tokens_batch.squeeze(0)
masks_batch = masks_batch.squeeze(0)
code_masks_batch = code_masks_batch.squeeze(0)
# Move to the same device as the input
if device is not None:
tokens_batch = tokens_batch.to(device)
masks_batch = masks_batch.to(device)
code_masks_batch = code_masks_batch.to(device)
return tokens_batch, masks_batch
return tokens_batch, masks_batch, code_masks_batch
def action(self, action: torch.Tensor) -> torch.Tensor:
"""
This method is not used since we override __call__.
Required by ActionProcessorStep ABC.
"""
tokens, _ = self._tokenize_action(action)
tokens, _, _ = self._tokenize_action(action)
return tokens
def get_config(self) -> dict[str, Any]:
+3 -1
View File
@@ -21,6 +21,8 @@ from lerobot.utils.import_utils import make_device_from_device_class
from .config import RobotConfig
from .robot import Robot
logger = logging.getLogger(__name__)
def make_robot_from_config(config: RobotConfig) -> Robot:
# TODO(Steven): Consider just using the make_device_from_device_class for all types
@@ -118,7 +120,7 @@ def ensure_safe_goal_position(
}
if warnings_dict:
logging.warning(
logger.warning(
"Relative goal position magnitude had to be clamped to be safe.\n"
f"{pformat(warnings_dict, indent=4)}"
)
+1 -6
View File
@@ -175,17 +175,12 @@ def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None:
"repo_id": repo_id,
"tag": version_tag,
"repo_type": "dataset",
"exist_ok": True,
}
if revision is not None:
tag_kwargs["revision"] = revision
try:
from contextlib import suppress # noqa: PLC0415
from huggingface_hub.errors import RevisionNotFoundError # noqa: PLC0415
with suppress(RevisionNotFoundError):
api.delete_tag(repo_id, tag=version_tag, repo_type="dataset")
api.create_tag(**tag_kwargs)
print(f"[lerobot-annotate] tagged {repo_id} as {version_tag}", flush=True)
except Exception as exc: # noqa: BLE001
@@ -94,14 +94,6 @@ Merge multiple datasets from a list of local dataset paths:
--operation.repo_ids "['pusht_train', 'pusht_val']" \
--operation.roots "['/path/to/pusht_train', '/path/to/pusht_val']"
Merge multiple datasets while keeping one file per source file (no video/data stitching):
lerobot-edit-dataset \
--new_repo_id lerobot/pusht_merged \
--operation.type merge \
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" \
--operation.concatenate_videos false \
--operation.concatenate_data false
Remove camera feature:
lerobot-edit-dataset \
--repo_id lerobot/pusht \
@@ -265,9 +257,6 @@ class SplitConfig(OperationConfig):
class MergeConfig(OperationConfig):
repo_ids: list[str] | None = None
roots: list[str] | None = None
# When False, keep one file per source file instead of packing into shards.
concatenate_videos: bool = True
concatenate_data: bool = True
@OperationConfig.register_subclass("remove_feature")
@@ -472,8 +461,6 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
datasets,
output_repo_id=cfg.new_repo_id,
output_dir=output_dir,
concatenate_videos=cfg.operation.concatenate_videos,
concatenate_data=cfg.operation.concatenate_data,
)
logging.info(f"Merged dataset saved to {output_dir}")
+94 -2
View File
@@ -95,6 +95,67 @@ from lerobot.utils.utils import (
)
def _wrap_text_to_width(text: str, cv2, font, scale: int, thickness: int, max_width: int) -> list[str]:
"""Greedy word-wrap using measured pixel width so text fits the frame."""
words = text.split()
lines: list[str] = []
current = ""
for word in words:
candidate = f"{current} {word}".strip()
(w, _), _ = cv2.getTextSize(candidate, font, scale, thickness)
if w > max_width and current:
lines.append(current)
current = word
else:
current = candidate
if current:
lines.append(current)
return lines or [""]
def _annotate_eval_frames(
frames: np.ndarray, task: str | None, subtask: str | None
) -> np.ndarray:
"""Overlay the high-level task and predicted subtask onto rendered frames.
``frames`` is ``(n_envs, H, W, C)`` uint8. Best-effort: if OpenCV isn't
available the frames are returned unchanged so eval never fails over a
visualization concern.
"""
if frames.ndim != 4 or frames.shape[-1] != 3:
return frames
try:
import cv2 # noqa: PLC0415
except ImportError:
return frames
width = frames.shape[2]
font = cv2.FONT_HERSHEY_SIMPLEX
scale = 0.5
margin = 6
max_width = width - 2 * margin
lines: list[str] = []
if task:
lines += _wrap_text_to_width(f"Task: {task}", cv2, font, scale, 1, max_width)
if subtask:
lines += _wrap_text_to_width(f"Subtask: {subtask}", cv2, font, scale, 1, max_width)
if not lines:
return frames
out = frames.copy()
for i in range(out.shape[0]):
img = np.ascontiguousarray(out[i])
y = 18
for line in lines:
# Black outline then white fill so text stays legible on any scene.
cv2.putText(img, line, (margin, y), font, scale, (0, 0, 0), 3, cv2.LINE_AA)
cv2.putText(img, line, (margin, y), font, scale, (255, 255, 255), 1, cv2.LINE_AA)
y += 20
out[i] = img
return out
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
@@ -325,11 +386,42 @@ def eval_policy(
return
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
if isinstance(env, gym.vector.SyncVectorEnv):
ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023
frames = np.stack([env.envs[i].render() for i in range(n_to_render_now)]) # noqa: B023
elif hasattr(env, "call"):
# Here we must render all frames and discard any we don't need.
# Covers AsyncVectorEnv and _LazyAsyncVectorEnv (which wraps one).
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
frames = np.stack(env.call("render")[:n_to_render_now])
else:
return
# Overlay the high-level task and (for hierarchical policies like
# pi052) the predicted low-level subtask onto each frame. Both are
# best-effort: missing values just skip that line.
try:
tasks = list(env.call("task_description"))
except (AttributeError, NotImplementedError):
try:
tasks = list(env.call("task"))
except (AttributeError, NotImplementedError):
tasks = None
# Per-env subtasks when available (batched hierarchical policies);
# fall back to the scalar last_subtask for single-env / other policies.
subtasks = getattr(policy, "last_subtasks", None)
subtask_scalar = getattr(policy, "last_subtask", None)
annotated = []
for i in range(frames.shape[0]):
if subtasks is not None and i < len(subtasks):
subtask_i = subtasks[i]
else:
subtask_i = subtask_scalar
annotated.append(
_annotate_eval_frames(
frames[i : i + 1],
tasks[i] if tasks is not None and i < len(tasks) else None,
subtask_i,
)[0]
)
ep_frames.append(np.stack(annotated))
if max_episodes_rendered > 0:
video_paths: list[str] = []
File diff suppressed because it is too large Load Diff
+392 -143
View File
@@ -20,6 +20,7 @@ Requires: pip install 'lerobot[training]' (includes dataset + accelerate + wand
import dataclasses
import logging
import os
import time
from contextlib import nullcontext
from pprint import pformat
@@ -36,8 +37,6 @@ from tqdm import tqdm
from lerobot.common.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_batch_size,
load_training_num_processes,
load_training_state,
save_checkpoint,
update_last_checkpoint,
@@ -45,8 +44,7 @@ from lerobot.common.train_utils import (
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state
from lerobot.datasets.factory import make_train_eval_datasets
from lerobot.datasets import EpisodeAwareSampler, WeightedEpisodeAwareSampler, make_dataset
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
@@ -102,9 +100,6 @@ def update_policy(
start_time = time.perf_counter()
policy.train()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
# Compute sample weights if a weighter is provided
sample_weights = None
weight_stats = None
@@ -164,11 +159,173 @@ def update_policy(
train_metrics.grad_norm = grad_norm.item()
train_metrics.lr = optimizer.param_groups[0]["lr"]
train_metrics.update_s = time.perf_counter() - start_time
if torch.cuda.is_available():
train_metrics.gpu_mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
return train_metrics, output_dict
def _print_debug_text_predictions(
policy: Any, batch: dict[str, Any], step: int, n_samples: int = 5
) -> None:
"""Forward the current batch and print head-argmax vs label per supervised position.
Opt-in via ``LEROBOT_DEBUG_PREDS_EVERY=<step_interval>``. Only the
policy types that expose ``debug_text_predictions`` participate
(currently PI052); others are silently skipped. Pretty-prints up to
``n_samples`` samples from the current batch, showing the prompt,
every supervised position's (label, prediction, ✓/✗), and a
per-sample token-accuracy summary the cheapest "is text training
actually learning anything" signal.
"""
# Accelerator/DDP wraps the policy in a ``module`` attribute and
# doesn't proxy custom methods through, so a naive
# ``hasattr(policy, "debug_text_predictions")`` returns False on the
# wrapper — and the helper would silently no-op. Walk through any
# ``.module`` indirection (DDP, FSDP, ``accelerator.prepare`` wrappers)
# to reach the raw policy that actually defines the method.
inner = policy
while hasattr(inner, "module") and not hasattr(inner, "debug_text_predictions"):
inner = inner.module
if not hasattr(inner, "debug_text_predictions"):
logging.warning(
"LEROBOT_DEBUG_PREDS_EVERY set but policy %s has no "
"debug_text_predictions method — skipping dump.",
type(inner).__name__,
)
return
try:
debug = inner.debug_text_predictions(batch, max_samples=n_samples)
except Exception as exc: # noqa: BLE001
logging.warning("debug_text_predictions failed: %s", exc, exc_info=True)
return
if not debug:
logging.warning(
"debug_text_predictions returned no supervised samples — "
"current batch has no text labels."
)
return
policy = inner # used below for select_message-style decoding parity
# Build a tokenizer for decoding — match training side exactly.
try:
from transformers import AutoTokenizer # noqa: PLC0415
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: PLC0415
register_paligemma_loc_tokens,
)
tok_name = (
getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
)
tokenizer = register_paligemma_loc_tokens(AutoTokenizer.from_pretrained(tok_name))
except Exception as exc: # noqa: BLE001
logging.warning("debug preds: tokenizer load failed: %s", exc)
return
ids = debug["input_ids"]
labels = debug["labels"]
preds = debug["predictions"]
attn = debug["attention_mask"]
n = ids.shape[0]
print(
f"\n========== STEP {step} DEBUG PREDICTIONS ({n} samples) ==========",
flush=True,
)
for s in range(n):
a = attn[s].tolist()
real = sum(a)
sid = ids[s].tolist()
sl = labels[s].tolist()
sp = preds[s].tolist()
prompt = tokenizer.decode(sid[:real], skip_special_tokens=False)
print(f"\n --- sample {s + 1}/{n} ---", flush=True)
print(f" prompt: {prompt!r}", flush=True)
# Ground-truth target (the contiguous supervised label span).
sup_ids = [int(sid[i]) for i in range(real) if sl[i] != -100]
if sup_ids:
print(
f" target (ground truth) : {tokenizer.decode(sup_ids, skip_special_tokens=False)!r}",
flush=True,
)
# Training-side teacher-forced argmax on the same prompt+target.
n_sup = n_ok = 0
teacher_chars: list[int] = []
for i in range(1, real):
label = sl[i]
if label == -100:
continue
n_sup += 1
pred = int(sp[i - 1])
teacher_chars.append(pred)
if label == pred:
n_ok += 1
teacher_text = (
tokenizer.decode(teacher_chars, skip_special_tokens=False) if teacher_chars else ""
)
acc = n_ok / max(n_sup, 1)
print(
f" training argmax (teacher-fed) : {teacher_text!r} acc={n_ok}/{n_sup}={acc:.1%}",
flush=True,
)
print("=" * 60 + "\n", flush=True)
def _build_vqa_oversample_weights(dataset: Any, target_fraction: float) -> "torch.Tensor | None":
"""Build per-frame sampling weights that oversample VQA-annotated frames.
Scans the dataset's ``language_events`` column for frames carrying a
``vqa``-style annotation and returns a weight tensor (length == total
dataset frames) such that, under multinomial sampling, VQA frames make up
roughly ``target_fraction`` of the training stream.
Returns ``None`` ( fall back to uniform episode-aware sampling) when VQA
frames cannot be detected or there are none.
"""
if not 0.0 < target_fraction < 1.0:
logging.warning(
"vqa_target_fraction must be in (0, 1); got %s — VQA oversampling disabled.",
target_fraction,
)
return None
hf = getattr(dataset, "hf_dataset", None)
if hf is None or "language_events" not in getattr(hf, "column_names", []):
logging.warning(
"Dataset has no `language_events` column — VQA oversampling disabled."
)
return None
events_col = hf["language_events"]
n_frames = len(events_col)
is_vqa = torch.zeros(n_frames, dtype=torch.bool)
for i, rows in enumerate(events_col):
if rows and any((row or {}).get("style") == "vqa" for row in rows):
is_vqa[i] = True
n_vqa = int(is_vqa.sum())
if n_vqa == 0:
logging.warning("No `vqa` annotations found in the dataset — VQA oversampling disabled.")
return None
n_other = n_frames - n_vqa
# Solve target = (n_vqa·w) / (n_vqa·w + n_other) for the VQA weight w.
# Clamp to ≥ 1 so VQA frames are never *down*-weighted below uniform.
weight = (target_fraction * n_other) / ((1.0 - target_fraction) * max(n_vqa, 1))
weight = max(weight, 1.0)
weights = torch.ones(n_frames, dtype=torch.double)
weights[is_vqa] = weight
logging.info(
"VQA oversampling: %d/%d frames carry a `vqa` annotation (%.2f%%); "
"weighting them x%.2f to target ~%.0f%% of the training stream.",
n_vqa,
n_frames,
100.0 * n_vqa / n_frames,
weight,
100.0 * target_fraction,
)
return weights
@parser.wrap()
def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
"""
@@ -198,15 +355,26 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
# We set find_unused_parameters=True to handle models with conditional computation
if accelerator is None:
from accelerate.utils import DistributedDataParallelKwargs
from datetime import timedelta
from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# Bump the c10d store-get / barrier timeout so the rank-0-only
# ``make_dataset`` block below doesn't trigger a barrier crash on
# large datasets. Default is 10 min (``store->get`` 600 s); a
# 32 k-episode v3 dataset (e.g. ``robocasa_pretrain_human300_v4``)
# spends >13 min on rank 0 building the episode/frame index
# while ranks 1-N idle at ``wait_for_everyone()`` and crash with
# ``DistBackendError: ... wait timeout after 600000ms``. 2 h is
# plenty of headroom; fast paths are unaffected.
ipg_kwargs = InitProcessGroupKwargs(timeout=timedelta(hours=2))
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
force_cpu = cfg.trainable_config.device == "cpu"
accelerator = Accelerator(
step_scheduler_with_optimizer=False,
kwargs_handlers=[ddp_kwargs],
kwargs_handlers=[ddp_kwargs, ipg_kwargs],
cpu=force_cpu,
)
@@ -240,24 +408,22 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Dataset loading synchronization: the global main process downloads once to the shared
# dataset root, then a barrier lets every other rank read the already-populated copy.
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
# Dataset loading synchronization: main process downloads first to avoid race conditions
if is_main_process:
logging.info("Creating dataset")
dataset, eval_dataset = make_train_eval_datasets(cfg)
dataset = make_dataset(cfg)
accelerator.wait_for_everyone()
# Other ranks read from the shared copy populated by the main process.
# Now all other processes can safely load the dataset
if not is_main_process:
dataset, eval_dataset = make_train_eval_datasets(cfg)
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
# using the eval.py instead, with gym_dora environment and dora-rs.
eval_env = None
if cfg.env_eval_freq > 0 and cfg.env is not None and is_main_process:
if cfg.eval_freq > 0 and cfg.env is not None and is_main_process:
logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
@@ -302,6 +468,27 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
active_cfg = cfg.trainable_config
processor_pretrained_path = active_cfg.pretrained_path
# pi052: even when loading pretrained weights, build the processors
# from the current pi052 config so the recipe text-label and FAST
# action-label steps are generated and not silently swapped for the
# checkpoint's older processor stack.
if cfg.policy.type == "pi052" and processor_pretrained_path is not None and not cfg.resume:
logging.warning(
"pi052 is loading pretrained weights from %s, but building processors from the current "
"pi052 config so recipe text labels and FAST action labels are generated.",
processor_pretrained_path,
)
processor_pretrained_path = None
if (
getattr(active_cfg, "use_relative_actions", False)
and processor_pretrained_path is not None
and not cfg.resume
):
logging.warning(
"use_relative_actions=true with pretrained processors can skip relative transforms if "
"the checkpoint processors do not define them. Building processors from current policy config."
)
processor_pretrained_path = None
processor_kwargs = {}
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
@@ -310,6 +497,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if cfg.is_reward_model_training:
processor_kwargs["dataset_meta"] = dataset.meta
# For pi052 (and any future policy that auto-fits part of its
# preprocessing per-dataset), pass the dataset repo id so the
# processor factory can locate/refresh dataset-specific artifacts
# (e.g. fitted FAST tokenizers per Pertsch et al. 2025 [64],
# π0.5 §III.C).
if cfg.policy.type == "pi052":
processor_kwargs["dataset_repo_id"] = cfg.dataset.repo_id
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
preprocessor_overrides = {
"device_processor": {"device": device.type},
@@ -394,47 +589,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if not cfg.dataset.streaming:
# All non-streaming (map-style) datasets use EpisodeAwareSampler.
# The order is a pure function of (seed, epoch), so every rank independently produces the
# same permutation. accelerate then shards it disjointly across ranks via BatchSamplerShard
# without needing a `generator` attribute to synchronize an RNG, and resume is sample-exact.
if hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
shuffle=True,
seed=cfg.seed if cfg.seed is not None else 0,
)
if cfg.resume and step > 0:
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so
# use the values recorded in the checkpoint (falling back to the current ones for older
# ckpts that did not store them).
saved_num_processes = load_training_num_processes(cfg.checkpoint_path)
saved_batch_size = load_training_batch_size(cfg.checkpoint_path)
ckpt_num_processes = saved_num_processes or accelerator.num_processes
ckpt_batch_size = saved_batch_size or cfg.batch_size
if is_main_process and saved_num_processes not in (None, accelerator.num_processes):
logging.warning(
f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was "
f"written with num_processes={saved_num_processes}. The data order resumes at the "
"right epoch/offset, but per-rank sample-exactness requires the same world size."
)
if is_main_process and saved_batch_size not in (None, cfg.batch_size):
logging.warning(
f"Resuming with batch_size={cfg.batch_size} but the checkpoint was written with "
f"batch_size={saved_batch_size}. The data order resumes at the right epoch/offset, "
"but per-rank sample-exactness requires the same batch size."
)
sampler_state = compute_sampler_state(step, len(sampler), ckpt_batch_size, ckpt_num_processes)
sampler.load_state_dict(sampler_state)
if is_main_process:
logging.info(
f"Resuming data order at epoch {sampler_state['epoch']}, "
f"sample {sampler_state['start_index']}"
)
from_indices = dataset.meta.episodes["dataset_from_index"]
to_indices = dataset.meta.episodes["dataset_to_index"]
# When `vqa_target_fraction` is set, oversample VQA-annotated
# frames via a weighted sampler; otherwise plain episode-aware.
vqa_weights = None
if cfg.vqa_target_fraction is not None and not cfg.dataset.streaming:
vqa_weights = _build_vqa_oversample_weights(dataset, cfg.vqa_target_fraction)
if vqa_weights is not None:
sampler = WeightedEpisodeAwareSampler(
from_indices,
to_indices,
vqa_weights,
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
)
else:
sampler = EpisodeAwareSampler(
from_indices,
to_indices,
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
)
else:
shuffle = True
sampler = None
@@ -456,33 +635,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
)
# Build eval dataloader if a held-out split exists
eval_dataloader = None
if eval_dataset is not None:
eval_ds = eval_dataset
if cfg.max_eval_samples > 0 and hasattr(eval_dataset, "hf_dataset"):
task_arr = eval_dataset.hf_dataset.data.column("task_index").to_numpy()
unique_tasks = sorted(set(task_arr.tolist()))
per_task = max(1, cfg.max_eval_samples // len(unique_tasks))
selected: list[int] = []
for t in unique_tasks:
frames = (task_arr == t).nonzero()[0][:per_task]
selected.extend(frames.tolist())
eval_ds = torch.utils.data.Subset(eval_dataset, selected)
eval_collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
eval_dataloader = torch.utils.data.DataLoader(
eval_ds,
batch_size=cfg.batch_size,
shuffle=False,
num_workers=cfg.num_workers,
pin_memory=device.type == "cuda",
drop_last=False,
collate_fn=eval_collate_fn,
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
)
# Prepare everything with accelerator
accelerator.wait_for_everyone()
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
@@ -492,23 +644,61 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
policy.train()
# ------------------------------------------------------------------
# EMA setup
# ------------------------------------------------------------------
# Shadow copy of the trainable params for late-training averaging
# (Chi et al. 2023 Diffusion Policy §V.D; openpi JAX trainer ships
# this with decay=0.999 for pi05_libero; openpi PyTorch port and
# LeRobot main both skip it). Off by default; opt in with
# ``--ema.enable=true``. Implemented via ema-pytorch
# (https://github.com/lucidrains/ema-pytorch) — the standard PyTorch
# EMA library, also used by lucidrains' diffusion repos.
ema = None
if cfg.ema.enable and is_main_process:
from ema_pytorch import EMA # noqa: PLC0415
ema = EMA(
accelerator.unwrap_model(policy),
beta=cfg.ema.decay,
update_after_step=cfg.ema.warmup_steps,
update_every=1, # update on every ema.update() call
# Don't register the live model as an ema submodule — accelerator
# already owns its lifecycle, and double-registration would
# double-count its params in ``ema.state_dict()``.
include_online_model=False,
)
ema.to(accelerator.device)
logging.info(
"EMA enabled (ema-pytorch): beta=%g, update_after_step=%d, "
"use_for_eval=%s, use_for_wandb_examples=%s",
cfg.ema.decay,
cfg.ema.warmup_steps,
cfg.ema.use_for_eval,
cfg.ema.use_for_wandb_examples,
)
# Resume the EMA shadow if a previous run wrote one.
if cfg.checkpoint_path is not None:
ema_path = cfg.checkpoint_path / "training_state" / "ema_state.pt"
if ema_path.exists():
logging.info("Resuming EMA shadow from %s", ema_path)
try:
ema.load_state_dict(torch.load(ema_path, map_location=accelerator.device))
except Exception as exc: # noqa: BLE001
logging.warning(
"Failed to load EMA shadow (%s) — restarting EMA from "
"current live weights",
exc,
)
train_metrics = {
# Per-rank loss reflects only one shard of the global batch; mean recovers the loss DDP
# is actually optimizing. grad_norm and lr are already identical on every rank (post
# gradient sync / deterministic scheduler) so reducing them would be a no-op collective.
"loss": AverageMeter("loss", ":.3f", reduction="mean"),
"loss": AverageMeter("loss", ":.3f"),
"grad_norm": AverageMeter("grdn", ":.3f"),
"lr": AverageMeter("lr", ":0.1e"),
# Report the slowest rank for bottleneck-style timings so multi-GPU runs surface the
# true straggler instead of rank 0's view.
"update_s": AverageMeter("updt_s", ":.3f", reduction="max"),
"dataloading_s": AverageMeter("data_s", ":.3f", reduction="max"),
# Derived from the post-reduce max step time; set once per log window on the main rank.
"samples_per_s": AverageMeter("smp/s", ":.0f"),
"update_s": AverageMeter("updt_s", ":.3f"),
"dataloading_s": AverageMeter("data_s", ":.3f"),
}
if torch.cuda.is_available():
# max() because headroom is gated by the worst-case rank.
train_metrics["gpu_mem_gb"] = AverageMeter("mem_gb", ":.2f", reduction="max")
# Keep global batch size for logging; MetricsTracker handles world size internally.
effective_batch_size = cfg.batch_size * accelerator.num_processes
@@ -554,58 +744,97 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
sample_weighter=sample_weighter,
)
# EMA update: pull one step of the live weights into the shadow.
# Runs only on the main process (the shadow lives there); other
# ranks rely on the live model staying in sync via accelerator.
# ``ema-pytorch`` holds an internal reference to the online model
# (set at construction), so ``ema.update()`` takes no args.
if ema is not None:
ema.update()
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.
step += 1
if is_main_process:
progbar.update(1)
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_env_eval_step = cfg.env_eval_freq > 0 and step % cfg.env_eval_freq == 0
is_eval_step = cfg.eval_steps > 0 and eval_dataloader is not None and step % cfg.eval_steps == 0
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
# Optional periodic head-prediction dump for the LM head:
# ``LEROBOT_DEBUG_PREDS_EVERY=1000`` prints 5 samples + per-token
# (label, argmax, ✓/✗) every 1000 steps. Cheap diagnostic to see
# whether the text head is actually learning what we expect, vs
# collapsing to a fixed token. Refilling the recipe-sample dump
# budget at the same cadence also redumps the raw input shapes.
_debug_preds_every = int(os.environ.get("LEROBOT_DEBUG_PREDS_EVERY", "0"))
if (
_debug_preds_every > 0
and step % _debug_preds_every == 0
and is_main_process
):
try:
from lerobot.policies.pi052 import text_processor_pi052 as _tp # noqa: PLC0415
_tp._DUMPED_SO_FAR = 0
_tp._DUMP_BUDGET = max(_tp._DUMP_BUDGET, 5)
except Exception: # noqa: BLE001
pass
_print_debug_text_predictions(policy, batch, step, n_samples=5)
if is_log_step:
# Collective reduce must run on every rank, before the main-process gate below.
train_tracker.reduce_across_ranks()
if is_main_process:
# Cluster-wide throughput, derived from the already-reduced (max) step time so it
# reflects the slowest rank — which is what actually gates the next iteration.
step_time = train_tracker.update_s.avg + train_tracker.dataloading_s.avg
if step_time > 0:
train_tracker.samples_per_s = effective_batch_size / step_time
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log sample weighting statistics if enabled
if sample_weighter is not None:
weighter_stats = sample_weighter.get_stats()
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
wandb_logger.log_dict(wandb_log_dict, step)
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log sample weighting statistics if enabled
if sample_weighter is not None:
weighter_stats = sample_weighter.get_stats()
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
# EMA observability: ``ema.step`` is the count of
# ``ema.update()`` calls (= optimizer steps once EMA is
# enabled); ``ema.initted`` flips to True once we've
# crossed ``update_after_step``.
if ema is not None:
wandb_log_dict["ema/step"] = int(ema.step.item())
wandb_log_dict["ema/initted"] = float(ema.initted.item())
wandb_log_dict["ema/beta"] = float(cfg.ema.decay)
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
if is_eval_step:
policy.eval()
eval_loss_sum = 0.0
n_eval_batches = 0
with torch.no_grad(), accelerator.autocast():
for eval_batch in eval_dataloader:
for cam_key in dataset.meta.camera_keys:
if cam_key in eval_batch and eval_batch[cam_key].dtype == torch.uint8:
eval_batch[cam_key] = eval_batch[cam_key].to(dtype=torch.float32) / 255.0
eval_batch = preprocessor(eval_batch)
loss, _ = policy.forward(eval_batch)
eval_loss_sum += loss.item()
n_eval_batches += 1
eval_loss = eval_loss_sum / max(n_eval_batches, 1)
policy.train()
if is_main_process:
logging.info(f"step {step}: eval_loss={eval_loss:.4f}")
if wandb_logger:
wandb_logger.log_dict({"eval_loss": eval_loss}, step=step, mode="eval")
# Periodic training-example dump to wandb (camera images + text
# fields + action endpoints). Opt-in via ``--wandb.log_examples_freq``;
# independent of ``--log_freq`` so you can keep scalar logs frequent
# and the heavier visual dump rare (e.g. every 5000 steps).
if (
wandb_logger is not None
and cfg.wandb.log_examples_freq > 0
and step % cfg.wandb.log_examples_freq == 0
and is_main_process
):
try:
# Optionally use the EMA shadow model directly for the
# predicted-action columns (matches what eval / deployment
# would see). ``ema-pytorch`` exposes the shadow as a
# full ``nn.Module`` at ``ema.ema_model``, so we just
# pass that instead of swap-and-restore.
target_policy = (
ema.ema_model
if (ema is not None and cfg.ema.use_for_wandb_examples)
else accelerator.unwrap_model(policy)
)
wandb_logger.log_training_examples(
batch=batch,
step=step,
camera_keys=list(dataset.meta.camera_keys),
n_samples=cfg.wandb.log_examples_n,
policy=target_policy,
predict_actions=cfg.wandb.log_examples_predict_actions,
)
except Exception as exc: # noqa: BLE001
logging.warning("wandb log_training_examples failed: %s", exc)
if cfg.save_checkpoint and is_saving_step:
if is_main_process:
@@ -620,23 +849,43 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
scheduler=lr_scheduler,
preprocessor=preprocessor,
postprocessor=postprocessor,
num_processes=accelerator.num_processes,
batch_size=cfg.batch_size,
)
update_last_checkpoint(checkpoint_dir)
# Save the EMA shadow alongside the training state so a
# resumed run picks up exactly where the live EMA left off.
# ``ema-pytorch.state_dict()`` returns the full shadow
# nn.Module's state dict + step/initted buffers; saved as
# .pt (the rest of training_state mixes formats already).
if ema is not None:
try:
ema_path = checkpoint_dir / "training_state" / "ema_state.pt"
ema_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(ema.state_dict(), ema_path)
except Exception as exc: # noqa: BLE001
logging.warning("Failed to save EMA shadow: %s", exc)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
accelerator.wait_for_everyone()
if cfg.env and is_env_eval_step:
if cfg.env and is_eval_step:
if is_main_process:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
# Use the EMA shadow model for eval when enabled —
# standard practice for diffusion-style policies (~13%
# lift on closed-loop success). ``ema.ema_model`` is a
# full nn.Module clone, so we just pass it through; no
# swap/restore on the live policy needed.
eval_target_policy = (
ema.ema_model
if (ema is not None and cfg.ema.use_for_eval)
else accelerator.unwrap_model(policy)
)
with torch.no_grad(), accelerator.autocast():
eval_info = eval_policy_all(
envs=eval_env, # dict[suite][task_id] -> vec_env
policy=accelerator.unwrap_model(policy),
policy=eval_target_policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
@@ -13,213 +13,77 @@
[SmolVLA](https://huggingface.co/papers/2506.01844) is a compact, efficient vision-language-action model that achieves competitive performance at reduced computational costs and can be deployed on consumer-grade hardware.
{% elif model_name == "act" %}
[Action Chunking with Transformers (ACT)](https://huggingface.co/papers/2304.13705) is an imitation-learning method that predicts short action chunks instead of single steps. It learns from teleoperated data and often achieves high success rates.
{% elif model_name == "tdmpc" %}
[TD-MPC](https://huggingface.co/papers/2203.04955) combines model-free and model-based approaches to improve sample efficiency and performance in continuous control tasks by using a learned latent dynamics model and terminal value function.
{% elif model_name == "diffusion" %}
[Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation.
{% elif model_name == "vqbet" %}
[VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills.
{% elif model_name == "pi0" %}
[π₀ (Pi0)](https://www.physicalintelligence.company/blog/pi0) is a general-purpose robot foundation model from Physical Intelligence: a generalist Vision-Language-Action policy that understands visual inputs, interprets natural language instructions, and controls a variety of different robots across diverse tasks. The LeRobot implementation is adapted from their open-source OpenPI repository.
**π₀ (Pi0)**
π₀ is a Vision-Language-Action model for general robot control, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
**Model Overview**
π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by Physical Intelligence. Unlike traditional robots that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks.
For more details, see the [Physical Intelligence π₀ blog post](https://www.physicalintelligence.company/blog/pi0).
{% elif model_name == "pi05" %}
[π₀.₅ (Pi05)](https://www.physicalintelligence.company/blog/pi05) is a Vision-Language-Action model from Physical Intelligence designed for open-world generalization: it evolves π₀ to generalize to entirely new environments and situations that were never seen during training. The LeRobot implementation is adapted from their open-source OpenPI repository.
{% elif model_name == "molmoact2" %}
[MolmoAct2](https://allenai.org/blog/molmoact2) is an open robotics foundation model from the Allen Institute for AI (Ai2) that maps camera images and language instructions to robot action chunks. The LeRobot implementation supports training and evaluation of the regular MolmoAct2 model.
{% elif model_name == "vla_jepa" %}
[VLA-JEPA](https://arxiv.org/abs/2602.10098) is a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
**π₀.₅ (Pi05) Policy**
π₀.₅ is a Vision-Language-Action model with open-world generalization, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
**Model Overview**
π₀.₅ represents a significant evolution from π₀, developed by Physical Intelligence to address a big challenge in robotics: open-world generalization. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training.
For more details, see the [Physical Intelligence π₀.₅ blog post](https://www.physicalintelligence.company/blog/pi05).
{% elif model_name == "gaussian_actor" %}
This is a Gaussian Actor policy (Gaussian policy with a tanh squash) — the policy-side component used by [Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) and related maximum-entropy continuous-control algorithms.
{% elif model_name == "pi0_fast" %}
[π₀-FAST (Pi0-FAST)](https://www.physicalintelligence.company/research/fast) is a Vision-Language-Action model for general robot control, from Physical Intelligence. It models continuous robot actions with autoregressive next-token prediction using FAST (Frequency-space Action Sequence Tokenization), training up to 5x faster than diffusion-based π₀.
{% elif model_name == "eo1" %}
[EO-1](https://huggingface.co/papers/2508.21112) is a Vision-Language-Action model for general robot control. It pairs a Qwen2.5-VL backbone for vision-language understanding with a continuous flow-matching action head that denoises action chunks.
{% elif model_name == "groot" %}
[GR00T N1.5](https://github.com/NVIDIA/Isaac-GR00T) is an open, cross-embodiment foundation model from NVIDIA for generalized humanoid robot reasoning and skills. It takes language and images as input and uses a flow-matching action transformer to predict actions conditioned on vision, language, and proprioception.
{% elif model_name == "multi_task_dit" %}
[Multi-Task Diffusion Transformer (DiT)](https://huggingface.co/papers/2507.05331) extends Diffusion Policy with a large Diffusion Transformer and text + vision conditioning for multi-task robot learning. It supports both diffusion and flow-matching objectives and reaches high dexterity with only ~450M parameters.
{% elif model_name == "wall_x" %}
[WALL-OSS](https://huggingface.co/papers/2509.11766) is an open-source foundation model for embodied intelligence from XSquare Robot. Built on Qwen2.5-VL, it uses a tightly-coupled multimodal architecture with flow matching to unify semantic reasoning and high-frequency action generation for cross-embodiment control.
{% elif model_name == "xvla" %}
[X-VLA](https://huggingface.co/papers/2510.10274) is a soft-prompted, flow-matching Vision-Language-Action framework that treats each robot or hardware setup as a "task" encoded with a small set of learnable Soft Prompt embeddings, letting a single model reconcile diverse robot morphologies, sensors, and action spaces.
{% else %}
This is a **{{ model_name }}** policy trained with [LeRobot](https://github.com/huggingface/lerobot).
_Model type not recognized — please update this template._
{% endif %}
{% set diagrams = {
"smolvla": "https://cdn-uploads.huggingface.co/production/uploads/640e21ef3c82bd463ee5a76d/aooU0a3DMtYmy_1IWMaIM.png",
"pi0": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-pi0%20(1).png",
"pi0_fast": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-pifast.png",
"eo1": "https://huggingface.co/datasets/HaomingSong/lerobot-documentation-images/resolve/main/lerobot/eo_pipeline.png",
"groot": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-groot-paper1%20(1).png",
"wall_x": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/walloss-lerobot-paper.png",
"xvla": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture.png"
} %}
{% if diagrams.get(model_name) %}
<p align="center">
<img src="{{ diagrams[model_name] }}" alt="{{ model_name }} architecture" width="85%"/>
</p>
{% endif %}
<!-- A short demo is worth more than any description! Record a GIF/video of the policy
running on your robot, upload it to this repo, and embed it here:
<p align="center">
<img src="https://huggingface.co/<hf_user>/<policy_repo_id>/resolve/main/demo.gif" width="60%"/>
</p>
-->
This policy has been trained and pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot).
{% set policy_docs = {
"act": "act",
"smolvla": "smolvla",
"pi0": "pi0",
"pi0_fast": "pi0fast",
"pi05": "pi05",
"molmoact2": "molmoact2",
"vla_jepa": "vla_jepa",
"eo1": "eo1",
"groot": "groot",
"xvla": "xvla",
"multi_task_dit": "multi_task_dit",
"wall_x": "walloss"
} %}
{% if policy_docs.get(model_name) %}Learn how to train and run it in the [LeRobot {{ model_name }} guide](https://huggingface.co/docs/lerobot/main/en/{{ policy_docs[model_name] }}), or browse the [full documentation](https://huggingface.co/docs/lerobot/index).
{% else %}See the [full LeRobot documentation](https://huggingface.co/docs/lerobot/index).
{% endif %}
See the full documentation at [LeRobot Docs](https://huggingface.co/docs/lerobot/index).
---
## How to Get Started with the Model
For a complete walkthrough, see the [training guide](https://huggingface.co/docs/lerobot/il_robots#train-a-policy).
Below is the short version on how to train and run inference/eval:
### Train from scratch
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/<dataset> \
--policy.type=act \
--output_dir=outputs/train/<desired_policy_repo_id> \
--job_name=lerobot_training \
--policy.device=cuda \
--policy.repo_id=${HF_USER}/<desired_policy_repo_id>
--wandb.enable=true
```
_Writes checkpoints to `outputs/train/<desired_policy_repo_id>/checkpoints/`._
### Evaluate the policy/run inference
```bash
lerobot-record \
--robot.type=so100_follower \
--dataset.repo_id=<hf_user>/eval_<dataset> \
--policy.path=<hf_user>/<desired_policy_repo_id> \
--episodes=10
```
Prefix the dataset repo with **eval\_** and supply `--policy.path` pointing to a local or hub checkpoint.
---
## Model Details
- **License:** {{ license | default("\[More Information Needed]", true) }}
{% if base_model %}- **Fine-tuned from:** [{{ base_model }}](https://huggingface.co/{{ base_model }})
{% endif %}{% if robot_type %}- **Robot type:** `{{ robot_type }}`
{% endif %}{% if cameras %}- **Cameras:** {% for camera in cameras %}`{{ camera }}`{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
{% if input_features or output_features %}
## Inputs & Outputs
The policy consumes these observation features and produces these action features.
{% if input_features %}
**Inputs**
| Feature | Type | Shape |
| --- | --- | --- |
{% for name, feature in input_features.items() %}| `{{ name }}` | {{ feature.type.value }} | `{{ feature.shape }}` |
{% endfor %}{% endif %}{% if output_features %}
**Outputs**
| Feature | Type | Shape |
| --- | --- | --- |
{% for name, feature in output_features.items() %}| `{{ name }}` | {{ feature.type.value }} | `{{ feature.shape }}` |
{% endfor %}{% endif %}{% endif %}
{% if dataset %}
## Training Dataset
- **Repository:** [{{ dataset.repo_id }}](https://huggingface.co/datasets/{{ dataset.repo_id }})
- **Episodes:** {{ dataset.episodes }}
- **Frames:** {{ dataset.frames }}
- **Frame rate:** {{ dataset.fps }} FPS
{% if dataset.tasks %}- **Task(s):** {% for task in dataset.tasks %}"{{ task }}"{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ dataset.repo_id }}">
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
</a>
{% endif %}
{% if training %}
## Training Configuration
| Setting | Value |
| --- | --- |
| Training steps | {{ training.steps }} |
| Batch size | {{ training.batch_size }} |
{% if training.optimizer %}| Optimizer | {{ training.optimizer }} |
{% endif %}{% if training.lr %}| Learning rate | {{ training.lr }} |
{% endif %}{% if training.seed is not none %}| Seed | {{ training.seed }} |
{% endif %}| LeRobot version | {{ training.lerobot_version }} |
{% endif %}
---
## How to Get Started with the Model
New to LeRobot? These guides cover the full workflow:
- **[Install LeRobot](https://huggingface.co/docs/lerobot/main/en/installation)** — set up the `lerobot` package.
- **[Hardware setup](https://huggingface.co/docs/lerobot/main/en/hardware_guide)** — assemble, wire, and calibrate your robot and cameras.
- **[Record data & train a policy](https://huggingface.co/docs/lerobot/en/il_robots)** — the end-to-end imitation-learning walkthrough.
- **[CLI cheat-sheet](https://huggingface.co/docs/lerobot/main/en/cheat-sheet)** — quick reference for the `lerobot-*` commands.
The short version to run and train this policy:
### Run the policy on your robot
```bash
lerobot-rollout \
--strategy.type=base \
--robot.type={{ robot_type | default("<your_robot_type>", true) }} \
--robot.port=<your_robot_port> \
--robot.cameras="{ <camera_1>: {type: opencv, index_or_path: <index_or_path>, width: 640, height: 480, fps: 30}, <camera_2>: {type: opencv, index_or_path: <index_or_path>, width: 640, height: 480, fps: 30}}" \
--policy.path={{ policy_repo_id | default("<hf_user>/<policy_repo_id>", true) }} \
--task="{% if dataset and dataset.tasks %}{{ dataset.tasks[0] }}{% else %}<your_task_description>{% endif %}" \
--duration=60
```
Replace the remaining `<...>` placeholders with your own values: `--robot.port` and the camera names/indices are specific to your machine, and the camera names must match the observation keys this policy was trained on.
When `--strategy.type=base` is used the script doesn't record the episodes. Skipping duration will make the policy run indefinitely. For more information look at [rollout documentation](https://huggingface.co/docs/lerobot/main/en/inference).
{% if base_model %}### Train your own policy
This policy type is usually fine-tuned from the pretrained base model [{{ base_model }}](https://huggingface.co/{{ base_model }}):
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/<dataset> \
--policy.path={{ base_model }} \
--output_dir=outputs/train/<policy_repo_id> \
--job_name=lerobot_training \
--policy.device=cuda \
--policy.repo_id=${HF_USER}/<policy_repo_id> \
--wandb.enable=true
```
{% else %}### Train your own policy
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/<dataset> \
--policy.type={{ model_name }} \
--output_dir=outputs/train/<policy_repo_id> \
--job_name=lerobot_training \
--policy.device=cuda \
--policy.repo_id=${HF_USER}/<policy_repo_id> \
--wandb.enable=true
```
{% endif %}
_Writes checkpoints to `outputs/train/<policy_repo_id>/checkpoints/`._
---
## Evaluation
<!-- Report real-robot results here: run the policy several times per task and count the
successes. Delete the "No evaluation results" line and fill in this table instead:
| Task | Trials | Successes | Success rate |
| ---- | ------ | --------- | ------------ |
| pick the lego brick | 10 | 8 | 80% |
Also worth noting: anything that affects difficulty (new object positions, lighting,
distractors, a different robot of the same type, ...).
-->
_No evaluation results have been provided for this policy yet._
---
## Citation
If you use this policy, please cite the method linked in the description above, along with LeRobot:
```bibtex
@misc{cadene2024lerobot,
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
howpublished = "\url{https://github.com/huggingface/lerobot}",
year = {2024}
}
```
+29
View File
@@ -0,0 +1,29 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LeRobot tool implementations.
Storage of the tool catalog (``meta/info.json["tools"]``) and the
``SAY_TOOL_SCHEMA`` constant live in PR 1
(``lerobot.datasets.language``). This package holds the *runnable*
implementations one file per tool, plus the registry that maps tool
names to classes.
See ``docs/source/tools.mdx`` for the authoring guide.
"""
from .base import Tool
from .registry import TOOL_REGISTRY, get_tools
from .say import SayTool
__all__ = ["Tool", "TOOL_REGISTRY", "get_tools", "SayTool"]
+58
View File
@@ -0,0 +1,58 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool protocol — the contract every runnable tool implementation honors.
Tools are the executable side of the OpenAI-style function-calling
abstraction the v3.1 language schema (PR 1) carries on assistant
messages: the schema describes *what can be called*, the tool
implementation describes *how to call it*.
Implementations live one-per-file under :mod:`lerobot.tools` (e.g.
``say.py`` for ``SayTool``) and are registered in
:mod:`lerobot.tools.registry`. The runtime instantiates them lazily so
heavy dependencies (torch models, audio backends, network clients,
hardware drivers) only load when the dataset actually declares the tool.
"""
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class Tool(Protocol):
"""Minimum surface every tool must expose."""
#: Name matching ``schema["function"]["name"]``. The runtime dispatcher
#: routes incoming ``tool_calls`` to the implementation by this key.
name: str
#: OpenAI-style function-call schema. Same dict the dataset stores in
#: ``meta/info.json["tools"]`` and the chat template renders into the
#: prompt.
schema: dict[str, Any]
def call(self, arguments: dict[str, Any]) -> Any:
"""Execute the tool with the model-provided arguments.
``arguments`` is the parsed dict from
``tool_calls[i]["function"]["arguments"]`` (already JSON-decoded
when the model emits a JSON-string by the chat-template
convention). Implementations validate the dict against their own
schema; the runtime only routes by name.
Return value is implementation-defined typically a tensor
(TTS audio), a Path (saved file), a dict (structured result), or
``None`` (side-effect-only call).
"""
+70
View File
@@ -0,0 +1,70 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool registry — name → implementation class.
Adding a new tool:
1. Drop a file under ``src/lerobot/tools/`` that defines a class
conforming to :class:`lerobot.tools.base.Tool` (must expose ``name``,
``schema``, ``call(arguments)``).
2. Register the class here under :data:`TOOL_REGISTRY`.
3. (Optional) Pre-populate ``meta/info.json["tools"]`` on your dataset
to advertise the schema to the chat-template + policy. The PR 2
annotation pipeline preserves anything you put there.
See ``docs/source/tools.mdx`` for the full authoring guide.
"""
from __future__ import annotations
from typing import Any
from .base import Tool
from .say import SayTool
#: Map from ``function.name`` to a class implementing :class:`Tool`.
#: The runtime instantiates entries lazily — registering a tool here is
#: essentially free (no model load happens until ``call`` runs).
TOOL_REGISTRY: dict[str, type] = {
"say": SayTool,
}
def get_tools(meta: Any, **kwargs: Any) -> dict[str, Tool]:
"""Build name → tool-instance dict from a dataset's declared catalog.
``meta`` is anything with a ``.tools`` attribute returning the
OpenAI-style schema list typically a
:class:`lerobot.datasets.dataset_metadata.LeRobotDatasetMetadata`.
Each entry whose ``function.name`` is registered here is
instantiated with the schema dict; tools whose name is unknown to
the registry are skipped (the schema still rides through the chat
template, the model just can't actually invoke that tool at
inference).
Extra keyword arguments are forwarded to every constructor useful
for runtime defaults like ``output_dir=Path("./tts_log")``.
"""
declared = list(meta.tools)
instances: dict[str, Tool] = {}
for schema in declared:
try:
name = schema["function"]["name"]
except (KeyError, TypeError):
continue
cls = TOOL_REGISTRY.get(name)
if cls is None:
continue
instances[name] = cls(schema=schema, **kwargs)
return instances
+169
View File
@@ -0,0 +1,169 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``SayTool`` — text-to-speech tool wrapping Kyutai's pocket-tts.
The first concrete tool implementation. PI052 and downstream runtime
dispatchers consume this when the model emits an assistant message
with ``tool_calls=[{function: {name: "say", arguments: {text: ...}}}]``.
Why pocket-tts:
- runs on CPU (no GPU dependency); ~6× real-time on a MacBook Air M4
- ~100M parameters, ~200ms first-chunk latency
- streamable, voice-cloneable
- pip-installable, MIT-style permissive license
The pocket-tts model is loaded **lazily** the first time ``call(...)``
runs (or eagerly via ``preload()``). Loading takes a few seconds and
several hundred MB of RAM, so we don't pay the cost when the tool is
merely *registered* only when it's *invoked*.
Optional dependency. Install with::
pip install lerobot[tools]
# or directly:
pip install pocket-tts
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from lerobot.datasets.language import SAY_TOOL_SCHEMA
logger = logging.getLogger(__name__)
@dataclass
class SayTool:
"""Speak a short utterance via Kyutai's pocket-tts.
Parameters
----------
schema:
Optional schema override; defaults to the canonical
``SAY_TOOL_SCHEMA`` from PR 1. Custom voices or extended
argument shapes can pass in a modified schema, but the
implementation only reads ``arguments["text"]``.
voice:
One of the pocket-tts catalog voices (``alba``, ``marius``,
``javert``, ``jean``, ``fantine``, ``cosette``, ``eponine``,
``azelma``) or a path to a ``.wav`` / ``.safetensors`` voice
file for cloning. See the pocket-tts model card for licensing.
output_dir:
If set, every ``call(...)`` writes a ``<timestamp>.wav`` audio
file there in addition to returning the PCM tensor.
``None`` (default) skips disk writes useful for live
playback paths that hand the tensor directly to a sounddevice
/ WebAudio sink.
"""
schema: dict[str, Any] = field(default_factory=lambda: dict(SAY_TOOL_SCHEMA))
voice: str = "alba"
output_dir: Path | None = None
name: str = field(init=False, default="say")
_model: Any = field(init=False, default=None, repr=False)
_voice_state: Any = field(init=False, default=None, repr=False)
_sample_rate: int = field(init=False, default=24000, repr=False)
# ------------------------------------------------------------------
# Lazy model load
# ------------------------------------------------------------------
def preload(self) -> None:
"""Load the pocket-tts model + voice state into memory.
Optional ``call(...)`` triggers this automatically on first
invocation. Useful when you want the multi-second load to
happen at startup rather than on the first ``say`` the policy
emits.
"""
if self._model is not None and self._voice_state is not None:
return
try:
from pocket_tts import TTSModel # noqa: PLC0415 (optional dep)
except ImportError as exc: # pragma: no cover (env-dependent)
raise ImportError(
"SayTool requires pocket-tts. Install with `pip install "
"lerobot[tools]` or `pip install pocket-tts`."
) from exc
logger.info("SayTool: loading pocket-tts model + voice=%r", self.voice)
self._model = TTSModel.load_model()
self._voice_state = self._model.get_state_for_audio_prompt(self.voice)
self._sample_rate = int(getattr(self._model, "sample_rate", 24000))
# ------------------------------------------------------------------
# Tool protocol
# ------------------------------------------------------------------
def call(self, arguments: dict[str, Any]) -> Any:
"""Speak ``arguments["text"]`` and return the PCM tensor.
Optionally also writes ``<output_dir>/<timestamp>.wav`` when
``self.output_dir`` is set. The returned tensor is a 1-D
``torch.Tensor`` of float32 PCM samples at
``self.sample_rate`` Hz directly playable by
``sounddevice.play(audio.numpy(), self.sample_rate)`` or
encodable by ``scipy.io.wavfile.write``.
"""
text = arguments.get("text")
if not isinstance(text, str) or not text.strip():
raise ValueError(
f"SayTool.call expects arguments={{'text': str}}, got {arguments!r}"
)
self.preload()
audio = self._model.generate_audio(self._voice_state, text)
if self.output_dir is not None:
self._write_wav(audio, text)
return audio
@property
def sample_rate(self) -> int:
"""PCM sample rate of the returned tensor (Hz)."""
return self._sample_rate
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _write_wav(self, audio: Any, text: str) -> Path:
"""Write a ``.wav`` next to ``output_dir`` for offline inspection."""
import time as _time # noqa: PLC0415
try:
import scipy.io.wavfile # noqa: PLC0415
except ImportError as exc: # pragma: no cover
raise ImportError(
"SayTool.output_dir requires scipy. `pip install scipy`."
) from exc
out_dir = Path(self.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
# One file per call; suffix with a millisecond timestamp + a
# short text snippet so a directory listing is informative.
snippet = "".join(c if c.isalnum() else "_" for c in text[:32]).strip("_")
ts_ms = int(_time.time() * 1000)
path = out_dir / f"say_{ts_ms}_{snippet}.wav"
# ``audio`` is a torch tensor; pocket-tts uses CPU, so a plain
# ``.numpy()`` is safe.
scipy.io.wavfile.write(path, self.sample_rate, audio.numpy())
return path
+1 -1
View File
@@ -22,7 +22,7 @@ from torch.utils.data._utils.collate import default_collate
from lerobot.datasets.language import LANGUAGE_COLUMNS
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices", *LANGUAGE_COLUMNS}
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
+1
View File
@@ -34,6 +34,7 @@ ACTION = "action"
ACTION_PREFIX = ACTION + "."
ACTION_TOKENS = ACTION + ".tokens"
ACTION_TOKEN_MASK = ACTION + ".token_mask"
ACTION_CODE_TOKEN_MASK = ACTION + ".code_token_mask"
REWARD = "next.reward"
TRUNCATED = "next.truncated"
DONE = "next.done"
+1 -50
View File
@@ -13,39 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from collections.abc import Callable
from typing import Any
import torch
from .utils import format_big_number
_VALID_REDUCTIONS = ("none", "max", "mean", "sum")
class AverageMeter:
"""
Computes and stores the average and current value
Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py
Args:
name: Display name of the metric.
fmt: Format string used when rendering the metric.
reduction: Cross-process reduction applied by :meth:`MetricsTracker.reduce_across_ranks`
before logging. One of ``"none"`` (per-rank value, default), ``"max"``, ``"mean"``,
or ``"sum"``. Use ``"max"`` for bottleneck-style metrics (e.g. dataloading or
update wall time) so multi-GPU runs report the slowest rank rather than rank 0.
"""
def __init__(self, name: str, fmt: str = ":f", reduction: str = "none"):
if reduction not in _VALID_REDUCTIONS:
raise ValueError(
f"Invalid reduction {reduction!r} for AverageMeter; expected one of {_VALID_REDUCTIONS}."
)
def __init__(self, name: str, fmt: str = ":f"):
self.name = name
self.fmt = fmt
self.reduction = reduction
self.reset()
def reset(self) -> None:
@@ -156,37 +138,6 @@ class MetricsTracker:
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
def reduce_across_ranks(self) -> None:
"""
Synchronises the running averages of every metric whose ``reduction`` is not ``"none"``
across all distributed processes (in-place).
This is a collective operation and MUST be invoked on every rank typically just before
logging. With no accelerator or in single-process runs it is a no-op. Without it, metrics
reported by the main process only reflect rank 0; for bottleneck-style timings
(``dataloading_s``, ``update_s``, ...) that means the slowest worker's stall is invisible.
"""
if self.accelerator is None or self.accelerator.num_processes <= 1:
return
buckets: dict[str, list[str]] = defaultdict(list)
for name, meter in self.metrics.items():
if meter.reduction != "none":
buckets[meter.reduction].append(name)
if not buckets:
return
device = self.accelerator.device
for reduction, names in buckets.items():
tensor = torch.tensor([self.metrics[n].avg for n in names], dtype=torch.float32, device=device)
reduced = self.accelerator.reduce(tensor, reduction=reduction)
for name, value in zip(names, reduced.tolist(), strict=True):
meter = self.metrics[name]
# Preserve avg == sum / count so a later .update() on this meter accumulates
# against the cluster view, not the stale per-rank history.
meter.avg = value
meter.sum = value * meter.count
def __str__(self) -> str:
display_list = [
f"step:{format_big_number(self.steps)}",
+41 -108
View File
@@ -38,20 +38,19 @@ import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.annotations.steerable_pipeline.frames import VideoFrameProvider # noqa: E402
from lerobot.annotations.steerable_pipeline.frames import ( # noqa: E402
VideoFrameProvider,
_decode_frames_av,
_decode_frames_ffmpeg,
)
class _FakeMeta:
"""Minimal metadata stub exposing ``video_keys`` / ``camera_keys``."""
def __init__(self, video_keys: list[str], image_keys: list[str], video_path: Path | None = None) -> None:
def __init__(self, video_keys: list[str], image_keys: list[str]) -> None:
self.video_keys = video_keys
self.camera_keys = [*video_keys, *image_keys]
self._video_path = video_path
self.episodes = {0: {f"videos/{key}/from_timestamp": 0.0 for key in video_keys}}
def get_video_file_path(self, episode_index: int, camera_key: str) -> Path:
return self._video_path
def test_default_camera_key_skips_image_only_cameras(tmp_path: Path, monkeypatch) -> None:
@@ -125,24 +124,15 @@ def sample_video(tmp_path: Path) -> Path:
return out
def _provider_for_video(tmp_path: Path, video: Path, monkeypatch) -> VideoFrameProvider:
"""A provider whose single camera resolves to ``video`` via fake metadata."""
fake = _FakeMeta(video_keys=["observation.images.cam"], image_keys=[], video_path=video)
import lerobot.datasets.dataset_metadata as meta_mod
def test_decode_frames_av_returns_one_uint8_frame_per_timestamp(sample_video: Path) -> None:
"""``_decode_frames_av`` decodes via PyAV directly — no torchcodec/torchvision.
monkeypatch.setattr(meta_mod, "LeRobotDatasetMetadata", lambda *a, **k: fake, raising=True)
return VideoFrameProvider(root=tmp_path, tolerance_s=0.2)
def test_decode_returns_one_uint8_frame_per_timestamp(
sample_video: Path, tmp_path: Path, monkeypatch
) -> None:
"""``_decode`` routes through ``decode_video_frames`` (torchcodec when
available, PyAV otherwise) no subprocess fallback.
This is the always-available fallback: torchcodec is unusable in some
containers and lerobot's ``pyav`` backend routes through the removed
``torchvision.io.VideoReader``.
"""
provider = _provider_for_video(tmp_path, sample_video, monkeypatch)
timestamps = [0.0, 1.0, 2.5]
frames = provider._decode(0, timestamps, "observation.images.cam")
frames = _decode_frames_av(sample_video, timestamps)
assert len(frames) == len(timestamps)
for frame in frames:
@@ -151,96 +141,39 @@ def test_decode_returns_one_uint8_frame_per_timestamp(
assert frame.shape == (3, 120, 160)
def test_frames_at_snaps_mid_frame_grid_to_real_frames(
sample_video: Path, tmp_path: Path, monkeypatch
) -> None:
"""Uniform sampling grids land mid-frame; ``frames_at`` must snap them to
real frame timestamps before decoding.
Regression: ``decode_video_frames`` rejects queries farther than
``tolerance_s`` (default 10 ms) from a decodable frame, so un-snapped
mid-frame queries raised ``FrameTimestampError`` wholesale and the plan
module silently lost its contact sheets for most episodes.
"""
from types import SimpleNamespace
fake = _FakeMeta(video_keys=["observation.images.cam"], image_keys=[], video_path=sample_video)
import lerobot.datasets.dataset_metadata as meta_mod
monkeypatch.setattr(meta_mod, "LeRobotDatasetMetadata", lambda *a, **k: fake, raising=True)
provider = VideoFrameProvider(root=tmp_path) # default 10 ms tolerance
# 10 fps fixture -> frames at 0.0, 0.1, ...; queries sit mid-frame.
record = SimpleNamespace(episode_index=0, frame_timestamps=[i / 10 for i in range(30)])
frames = provider.frames_at(record, [0.149, 1.234, 2.04], camera_key="observation.images.cam")
def test_decode_frames_av_picks_nearest_frame(sample_video: Path) -> None:
"""Repeated and out-of-order timestamps each resolve to the nearest frame."""
frames = _decode_frames_av(sample_video, [2.0, 0.0, 2.0])
assert len(frames) == 3
assert torch.equal(frames[0], frames[2])
assert not torch.equal(frames[0], frames[1])
def test_decode_frames_av_raises_on_missing_file(tmp_path: Path) -> None:
"""A missing video surfaces as an exception the caller can fall back on."""
with pytest.raises(Exception): # noqa: B017, PT011
_decode_frames_av(tmp_path / "does_not_exist.mp4", [0.0])
def test_decode_frames_ffmpeg_returns_one_uint8_frame_per_timestamp(sample_video: Path) -> None:
"""``_decode_frames_ffmpeg`` shells out to the ffmpeg CLI — the always-
available fallback that decodes AV1 and isolates crashes to a child
process.
"""
timestamps = [0.0, 1.0, 2.5]
frames = _decode_frames_ffmpeg(sample_video, timestamps)
assert len(frames) == len(timestamps)
for frame in frames:
assert isinstance(frame, torch.Tensor)
assert frame.dtype == torch.uint8
assert frame.shape == (3, 120, 160)
def test_decode_returns_empty_list_on_missing_file(tmp_path: Path, monkeypatch) -> None:
"""A missing video is a recoverable no-frames condition, never a crash."""
provider = _provider_for_video(tmp_path, tmp_path / "does_not_exist.mp4", monkeypatch)
assert provider._decode(0, [0.0], "observation.images.cam") == []
def test_episode_clip_path_trims_via_reencode_video(tmp_path: Path, monkeypatch) -> None:
"""Clip extraction delegates to ``video_utils.reencode_video`` with the
episode's ``[from_timestamp, to_timestamp)`` trim window — no subprocess.
"""
from types import SimpleNamespace
import lerobot.annotations.steerable_pipeline.frames as frames_mod
src = tmp_path / "src.mp4"
src.write_bytes(b"src")
fake = _FakeMeta(video_keys=["observation.images.cam"], image_keys=[], video_path=src)
fake.episodes[0]["videos/observation.images.cam/from_timestamp"] = 1.5
fake.episodes[0]["videos/observation.images.cam/to_timestamp"] = 4.0
import lerobot.datasets.dataset_metadata as meta_mod
monkeypatch.setattr(meta_mod, "LeRobotDatasetMetadata", lambda *a, **k: fake, raising=True)
captured = {}
def fake_reencode(
input_video_path,
output_video_path,
camera_encoder=None,
overwrite=False,
start_time_s=None,
end_time_s=None,
):
captured.update(
src=Path(input_video_path),
encoder=camera_encoder,
start_time_s=start_time_s,
end_time_s=end_time_s,
)
Path(output_video_path).write_bytes(b"clip")
monkeypatch.setattr(frames_mod, "reencode_video", fake_reencode, raising=True)
provider = VideoFrameProvider(root=tmp_path)
record = SimpleNamespace(episode_index=0, frame_timestamps=[0.0, 1.0])
out = provider.episode_clip_path(record, tmp_path / "clips")
assert out == tmp_path / "clips" / "ep_000000.mp4"
assert captured["src"] == src
assert captured["start_time_s"] == 1.5
assert captured["end_time_s"] == 4.0
# H.264 so the clip is decodable by vllm's libav build (sources are often AV1).
assert captured["encoder"].vcodec == "h264"
def test_videoframeprovider_serializes_decodes_with_a_lock() -> None:
"""torchcodec's cached per-file decoder is single-threaded; the provider
must own a dedicated lock that ``_decode`` holds around the decoder call.
"""
import threading
lock_field = VideoFrameProvider.__dataclass_fields__.get("_decode_lock")
assert lock_field is not None
assert lock_field.default_factory is threading.Lock
def test_decode_frames_ffmpeg_raises_on_missing_file(tmp_path: Path) -> None:
"""A missing video raises (non-zero ffmpeg exit), never crashes the job."""
if shutil.which("ffmpeg") is None:
pytest.skip("ffmpeg not available")
with pytest.raises(Exception): # noqa: B017, PT011
_decode_frames_ffmpeg(tmp_path / "does_not_exist.mp4", [0.0])
+13 -45
View File
@@ -22,7 +22,6 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import PIL.Image
import pytest
# ``lerobot.annotations`` imports pull in ``lerobot.datasets`` (-> the HF
@@ -52,10 +51,7 @@ from ._helpers import make_canned_responder # noqa: E402
class _StubFrameProvider:
"""Returns one sentinel object per requested timestamp."""
# A real (tiny) PIL image so the contact-sheet builder, which resizes and
# tiles frames, has something to draw. VQA still passes it through by
# identity via ``to_image_blocks``.
sentinel: Any = field(default_factory=lambda: PIL.Image.new("RGB", (32, 24)))
sentinel: Any = field(default_factory=lambda: object())
cameras: tuple[str, ...] = ("observation.images.top",)
calls: list[tuple[int, tuple[float, ...], str | None]] = field(default_factory=list)
video_calls: list[tuple[int, int, str | None]] = field(default_factory=list)
@@ -119,34 +115,6 @@ def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path:
assert len(plan_rows[-1]["content"].splitlines()) == 1
def test_module1_emit_memory_false_skips_memory_keeps_subtasks_and_plan(
fixture_dataset_root: Path, tmp_path: Path
) -> None:
"""``emit_memory=False`` drops ``memory`` rows (and their VLM calls) while
leaving subtask + plan generation intact symmetric to ``emit_plan``."""
vlm = make_canned_responder(
{
"atomic subtasks": {
"subtasks": [
{"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.4},
{"text": "wipe the counter from left to right", "start": 0.4, "end": 0.8},
{"text": "place the sponge into the sink", "start": 0.8, "end": 1.1},
]
},
"compressed semantic memory": {"memory": "wiped the counter once"},
},
)
module = PlanSubtasksMemoryModule(vlm=vlm, config=PlanConfig(emit_memory=False))
record = next(iter_episodes(fixture_dataset_root))
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
module.run_episode(record, staging)
rows = staging.read("plan")
styles = {r["style"] for r in rows}
assert "memory" not in styles
assert {"subtask", "plan"}.issubset(styles)
def test_module2_at_t0_emits_speech_only_no_interjection(fixture_dataset_root: Path, tmp_path: Path) -> None:
vlm = make_canned_responder(
{"acknowledgement the robot": {"text": "Sure, on it."}},
@@ -268,10 +236,8 @@ def test_module3_vqa_unique_per_frame_and_camera(single_episode_root: Path, tmp_
assert ts in frame_set
def test_module1_attaches_contact_sheets_to_subtask_prompt(
fixture_dataset_root: Path, tmp_path: Path
) -> None:
"""Module 1 sends timestamped contact-sheet image blocks (not a raw video block)."""
def test_module1_attaches_video_block_to_subtask_prompt(fixture_dataset_root: Path, tmp_path: Path) -> None:
"""Module 1 sends one ``type=video`` block covering the whole episode."""
captured: list[list[dict[str, Any]]] = []
payload = {
"subtasks": [
@@ -299,7 +265,7 @@ def test_module1_attaches_contact_sheets_to_subtask_prompt(
# call is the subtask one — keeps the assertions below focused on
# ``_generate_subtasks`` rather than fighting the order of unrelated
# text-only Module-1 sub-prompts.
config=PlanConfig(frames_per_second=2.0, max_frames_per_prompt=60, n_task_rephrasings=0),
config=PlanConfig(max_video_frames=5, frames_per_second=10.0, n_task_rephrasings=0),
frame_provider=provider,
)
record = next(iter_episodes(fixture_dataset_root))
@@ -324,14 +290,16 @@ def test_module1_attaches_contact_sheets_to_subtask_prompt(
video_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "video"]
image_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "image"]
text_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "text"]
assert video_blocks == [], "contact-sheet mode must not emit a raw video block"
assert len(image_blocks) >= 1, f"expected >=1 contact-sheet image block, got {content}"
assert all(isinstance(b["image"], PIL.Image.Image) for b in image_blocks)
assert len(video_blocks) == 1, f"expected exactly 1 video block, got {content}"
assert image_blocks == [], "subtask prompt must not mix image blocks with the video block"
assert len(text_blocks) == 1
# the prompt is prefixed with the contact-sheet reading instructions
assert text_blocks[0]["text"].startswith("CONTACT SHEETS")
# frames were decoded for this episode at episode-relative timestamps
assert provider.calls and provider.calls[0][0] == record.episode_index
# video block must wrap a list of frames covering the episode
assert isinstance(video_blocks[0]["video"], list)
assert len(video_blocks[0]["video"]) <= 5
# provider is called with target_count = min(duration * fps, max). With
# fps=10 on a ~1s episode that requests >max, so max=5 wins.
assert provider.video_calls and provider.video_calls[0][0] == record.episode_index
assert provider.video_calls[0][1] <= 5
def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path, tmp_path: Path) -> None:
-41
View File
@@ -1,41 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for ``vlm_client`` helpers."""
from __future__ import annotations
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.annotations.steerable_pipeline.vlm_client import _bind_serve_port # noqa: E402
def test_bind_serve_port_substitutes_placeholder() -> None:
# The {port} placeholder is replaced everywhere it appears, regardless of
# parallel vs single server — the bug was the single-server path passing
# it through unsubstituted.
cmd = "vllm serve M --max-model-len 32768 --port {port}"
assert _bind_serve_port(cmd, 8000) == "vllm serve M --max-model-len 32768 --port 8000"
def test_bind_serve_port_appends_when_missing() -> None:
assert _bind_serve_port("vllm serve M", 8001) == "vllm serve M --port 8001"
def test_bind_serve_port_leaves_explicit_port_untouched() -> None:
cmd = "vllm serve M --port 9000"
assert _bind_serve_port(cmd, 8000) == cmd
+9
View File
@@ -29,6 +29,15 @@ def test_message_recipe_validates_unknown_binding():
)
def test_canonical_recipe_loads():
"""The canonical PI052 blend YAML loads + validates."""
recipe = TrainingRecipe.from_yaml(
Path("src/lerobot/configs/recipes/subtask_mem_vqa_speech.yaml")
)
assert recipe.blend is not None
assert sum(c.weight for c in recipe.blend.values()) == pytest.approx(1.0)
def test_message_turn_requires_a_stream():
"""Every turn must declare a stream — None is rejected at construction.
-46
View File
@@ -289,52 +289,6 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
assert_dataset_iteration_works(aggr_ds)
def test_aggregate_datasets_without_concatenation(tmp_path, lerobot_dataset_factory):
"""With concatenation disabled, each source file is kept as its own destination file."""
ds_0 = lerobot_dataset_factory(
root=tmp_path / "no_stitch_0",
repo_id=f"{DUMMY_REPO_ID}_no_stitch_0",
total_episodes=3,
total_frames=60,
)
ds_1 = lerobot_dataset_factory(
root=tmp_path / "no_stitch_1",
repo_id=f"{DUMMY_REPO_ID}_no_stitch_1",
total_episodes=4,
total_frames=80,
)
aggr_root = tmp_path / "no_stitch_aggr"
aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_no_stitch_aggr",
aggr_root=aggr_root,
concatenate_videos=False,
concatenate_data=False,
)
with (
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(aggr_root)
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_no_stitch_aggr", root=aggr_root)
assert_episode_and_frame_counts(
aggr_ds, ds_0.num_episodes + ds_1.num_episodes, ds_0.num_frames + ds_1.num_frames
)
assert_dataset_iteration_works(aggr_ds)
assert_video_timestamps_within_bounds(aggr_ds)
# Two single-file sources stay as two files each, instead of being packed together.
assert len(list((aggr_root / "data").rglob("*.parquet"))) == 2
assert aggr_ds.meta.video_keys, "Test fixture should produce at least one video feature"
for key in aggr_ds.meta.video_keys:
assert len(list((aggr_root / "videos" / key).rglob("*.mp4"))) == 2
@pytest.mark.parametrize("mutation", ["mismatched_value", "missing_key"])
def test_aggregate_incomplete_video_encoder_info_warns_and_nuls_encoders(
tmp_path, lerobot_dataset_factory, caplog, mutation
-23
View File
@@ -83,29 +83,6 @@ def test_get_feature_stats_images():
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
def test_get_feature_stats_uint8_images_preserves_std():
data = np.array(
[
[
[[0, 64], [128, 255]],
[[255, 128], [64, 0]],
[[32, 96], [160, 224]],
],
[
[[16, 80], [144, 240]],
[[240, 144], [80, 16]],
[[48, 112], [176, 208]],
],
],
dtype=np.uint8,
)
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
expected_std = data.transpose(0, 2, 3, 1).reshape(-1, 3).std(axis=0).reshape(1, 3, 1, 1)
np.testing.assert_allclose(stats["std"], expected_std)
def test_get_feature_stats_axis_0_keepdims(sample_array):
expected = {
"min": np.array([[1, 2, 3]]),
+78
View File
@@ -343,6 +343,84 @@ def test_resolve_task_explicit_override_beats_rephrasings():
assert rendered["messages"][0]["content"] == "explicit override wins"
def test_flow_only_low_level_recipe_renders_without_target():
"""Regression: a flow-only ``low_level`` recipe has no ``target`` turn —
its supervision is the action-expert flow loss, not text-CE. It must
still render (not ``None``), otherwise every blend draw of it is dropped
and the action expert never receives a flow loss."""
recipe = TrainingRecipe(
messages=[
MessageTurn(
role="user",
content="${subtask}",
stream="low_level",
if_present="subtask",
),
],
bindings={"subtask": "active_at(t, style=subtask)"},
)
rendered = render_sample(
recipe=recipe,
persistent=PERSISTENT,
events=[],
t=0.5,
sample_idx=0,
task="clean kitchen",
)
assert rendered is not None
assert rendered["messages"] == [{"role": "user", "content": "subtask 0"}]
assert rendered["message_streams"] == ["low_level"]
assert rendered["target_message_indices"] == []
def test_vqa_frame_is_consumed_over_the_weighted_blend():
"""A frame carrying a VQA annotation renders the ``ask_vqa*`` sub-recipe
even when its blend weight is tiny VQA annotations are sparse and must
never be wasted on a subtask/action draw."""
recipe = TrainingRecipe(
blend={
"high_level_subtask": TrainingRecipe(
weight=0.99,
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="a subtask", stream="high_level", target=True),
],
),
"ask_vqa_top": TrainingRecipe(
weight=0.01,
bindings={
"vqa_query": "emitted_at(t, style=vqa, role=user, camera=observation.images.top)",
"vqa": "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)",
},
messages=[
MessageTurn(
role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"
),
MessageTurn(
role="assistant",
content="${vqa}",
stream="high_level",
target=True,
if_present="vqa",
),
],
),
}
)
# A frame WITH a vqa event renders VQA on every sample_idx, despite the
# ask_vqa weight being only 0.01.
for sample_idx in range(20):
rendered = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_1, t=1.0, sample_idx=sample_idx, task="x"
)
assert rendered["messages"][-1]["content"] == '{"count": 2}', sample_idx
# A frame WITHOUT a vqa event falls back to the normal weighted blend.
rendered = render_sample(recipe=recipe, persistent=PERSISTENT, events=[], t=1.0, sample_idx=0, task="x")
assert rendered["messages"][-1]["content"] == "a subtask"
def test_emitted_at_persistent_tolerates_small_timestamp_drift():
"""Persistent ``emitted_at`` should match within EMITTED_AT_TOLERANCE_S
so callers that derive ``t`` arithmetically (``frame_idx / fps``) still
+37 -88
View File
@@ -25,7 +25,7 @@ from datasets import Dataset # noqa: E402
from lerobot.datasets.io_utils import (
hf_transform_to_torch,
)
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
@@ -114,19 +114,6 @@ def test_shuffle():
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_shuffle_is_reproducible_across_instances():
# The order is a pure function of (seed, epoch), so two fresh samplers (e.g. two ranks)
# produce the same permutation without any generator synchronization.
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
epoch_0 = list(sampler_a)
assert list(sampler_b) == epoch_0
# Desyncing the global RNG must not affect the permutation.
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
assert list(sampler_c) == epoch_0
def test_negative_drop_first_frames_raises():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
@@ -152,85 +139,47 @@ def test_partial_episode_drop_warns(caplog):
assert "Episode 0" in caplog.text
# --- seeded (seed, epoch) shuffling, resume, and state ---
from lerobot.datasets.sampler import compute_sampler_state # noqa: E402
EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames
# --- WeightedEpisodeAwareSampler --------------------------------------------
@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100])
def test_deterministic_sampler_shuffle_is_permutation(num_frames):
for seed in (0, 1, 1234):
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed)
assert sorted(sampler) == list(range(num_frames))
def test_weighted_sampler_respects_episode_drop_and_length():
"""The episode-boundary frame filtering is applied before weighting,
and one epoch still yields ``len(indices)`` samples."""
# One episode, 10 frames; drop the last 2.
sampler = WeightedEpisodeAwareSampler([0], [10], frame_weights=torch.ones(10), drop_n_last_frames=2)
assert sampler.indices == list(range(8))
assert len(sampler) == 8
draws = list(sampler)
assert len(draws) == 8
# Dropped frames 8 and 9 must never be sampled.
assert all(d in set(range(8)) for d in draws)
def test_deterministic_sampler_epochs_reproduce_and_differ():
sampler_a = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
sampler_b = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
epoch_0 = list(sampler_a)
assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process
epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch
assert epoch_1 != epoch_0
assert sorted(epoch_1) == sorted(epoch_0)
sampler_a.set_epoch(0)
assert list(sampler_a) == epoch_0
assert list(EpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0
def test_weighted_sampler_oversamples_high_weight_frames():
"""A heavily-weighted frame dominates the draws."""
torch.manual_seed(0)
# 100 frames, frame 7 is weighted 1000x.
weights = torch.ones(100)
weights[7] = 1000.0
sampler = WeightedEpisodeAwareSampler([0], [100], frame_weights=weights)
counts = {}
for _ in range(20): # 20 epochs
for d in sampler:
counts[d] = counts.get(d, 0) + 1
total = sum(counts.values())
# Frame 7 should be the overwhelming majority of the 2000 draws.
assert counts.get(7, 0) / total > 0.9
def test_deterministic_sampler_resume_mid_epoch():
reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
epoch_0 = list(reference)
epoch_1 = list(reference)
for start in (0, 1, 4, len(epoch_0)):
resumed = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
resumed.load_state_dict({"epoch": 0, "start_index": start})
assert list(resumed) == epoch_0[start:]
# the resumed sampler continues into the same epoch 1 as the uninterrupted one
assert list(resumed) == epoch_1
def test_weighted_sampler_zero_weights_fall_back_to_uniform():
"""If every surviving frame has zero weight, sampling is uniform
rather than crashing."""
sampler = WeightedEpisodeAwareSampler([0], [6], frame_weights=torch.zeros(6))
draws = set(sampler)
assert draws.issubset(set(range(6)))
assert len(list(sampler)) == 6
def test_deterministic_sampler_construction_stores_only_boundaries():
# Construction is O(num_episodes), not O(num_frames): a million-frame single episode
# instantiates from just its boundaries without materializing a per-frame index list.
num_frames = 1_000_000
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
assert len(sampler) == num_frames
assert sampler._starts.shape == (1,) and sampler._cum_lengths.shape == (1,)
def test_deterministic_sampler_resume_is_exact_at_scale():
# Seeded randperm makes resume sample-exact at non-trivial sizes: regenerating the epoch's
# permutation and slicing from the saved offset reproduces the remaining order exactly.
num_frames = 100_000
reference = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
epoch_0 = list(reference)
assert sorted(epoch_0) == list(range(num_frames))
start = num_frames - 5
resumed = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
resumed.load_state_dict({"epoch": 0, "start_index": start})
assert list(resumed) == epoch_0[start:]
def test_compute_sampler_state():
# 100 frames, batch 10, 2 ranks -> 10 underlying batches, 5 per rank per epoch.
assert compute_sampler_state(step=0, num_frames=100, batch_size=10, num_processes=2) == {
"epoch": 0,
"start_index": 0,
}
# step 7 -> epoch 1, 2 per-rank batches in = 2 * 10 * 2 = 40 samples in
assert compute_sampler_state(step=7, num_frames=100, batch_size=10, num_processes=2) == {
"epoch": 1,
"start_index": 40,
}
# uneven epoch: 95 frames -> 10 underlying batches (last short), still 5 per rank
assert compute_sampler_state(step=12, num_frames=95, batch_size=10, num_processes=2) == {
"epoch": 2,
"start_index": 40,
}
# uneven sharding: 105 frames -> 11 underlying batches, 6 per rank (even_batches pads)
assert compute_sampler_state(step=11, num_frames=105, batch_size=10, num_processes=2) == {
"epoch": 1,
"start_index": 100,
}
def test_weighted_sampler_rejects_short_weight_vector():
with pytest.raises(ValueError, match="frame_weights"):
WeightedEpisodeAwareSampler([0], [10], frame_weights=torch.ones(5))
-13
View File
@@ -504,19 +504,6 @@ class TestReencodeVideo:
assert info["video.g"] == 6
assert info["video.crf"] == 23
@require_h264
def test_reencode_video_trim_window(self, tmp_path):
src = TEST_ARTIFACTS_DIR / "clip_6frames.mp4"
out = tmp_path / "trim_window.mp4"
cfg = VideoEncoderConfig(vcodec="h264")
reencode_video(src, out, camera_encoder=cfg, start_time_s=0.05, end_time_s=0.12, overwrite=True)
with av.open(str(out)) as container:
frames = list(container.decode(video=0))
# Only the frames at 0.067 and 0.1 s fall inside [0.05, 0.12).
assert len(frames) == 2
assert frames[0].time == pytest.approx(0.0, abs=1e-3)
class TestConcatenateVideoFiles:
def test_two_clips_frame_count(self, tmp_path):
@@ -0,0 +1,167 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attention-masking tests for the PI052 (π0.5 v2) text head.
Regression coverage for the text-CE collapse bug: PaliGemma's
``embed_prefix`` flags every language token ``att=0``, which
``make_att_2d_masks`` turns into one fully *bidirectional* block. Under
that mask the text cross-entropy degenerates into a copy task a
supervised target token attends to the tokens it is trained to predict
and the LM head never learns causal generation, so ``select_message``
collapses at inference.
``_mark_target_span_causal`` sets ``att=1`` on the supervised target
language positions so each target token attends causally among the
targets while staying bidirectional to images + the user prompt. These
tests pin that behaviour for the PaliGemma prefix layout.
"""
import pytest
import torch
# modeling_pi052 / modeling_pi05 import transformers transitively.
pytest.importorskip("transformers")
from lerobot.policies.pi05.modeling_pi05 import make_att_2d_masks # noqa: E402
from lerobot.policies.pi052.modeling_pi052 import ( # noqa: E402
_mark_target_span_causal,
_shifted_lin_ce,
)
def _shifted_ce(logits, labels):
"""Adapter: ``_shifted_lin_ce`` is Liger-fused (hidden @ lm_head_weightᵀ).
An identity ``lm_head_weight`` makes the computed logits equal ``logits``.
Liger's Triton kernel is GPU-only, so inputs run on CUDA; the loss is
returned on CPU so grad still flows back to the CPU ``logits`` leaf.
"""
if not torch.cuda.is_available():
pytest.skip("Liger fused CE requires CUDA")
vocab_size = logits.shape[-1]
eye = torch.eye(vocab_size, dtype=logits.dtype, device="cuda")
return _shifted_lin_ce(logits.cuda(), eye, labels.cuda()).cpu()
# ---------------------------------------------------------------------------
# A synthetic PI052 prefix layout: [images, prompt-lang, target-lang]
#
# indices 0-1 : 2 image tokens (att = 0)
# indices 2-4 : 3 user-prompt lang (att = 0)
# indices 5-8 : 4 supervised target lang(att = 0 from embed_prefix)
#
# ``text_labels`` covers the 7 language tokens; -100 on the prompt span,
# real ids on the 4-token target span. PaliGemma's prefix has no state
# token (unlike SmolVLA), so the lang span ends at the prefix end.
# ---------------------------------------------------------------------------
N_IMAGE = 2
N_PROMPT = 3
N_TARGET = 4
LANG_START = N_IMAGE
LANG_END = N_IMAGE + N_PROMPT + N_TARGET # = prefix length
PREFIX_LEN = LANG_END
def _embed_prefix_att_masks() -> torch.Tensor:
"""Mimic PaliGemma ``embed_prefix``: images + lang all att=0."""
return torch.zeros(1, PREFIX_LEN, dtype=torch.bool)
def _text_labels() -> torch.Tensor:
"""-100 over the prompt span, real ids over the target span."""
labels = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long)
labels[0, N_PROMPT:] = torch.arange(10, 10 + N_TARGET)
return labels
def _attends(prefix_att_masks: torch.Tensor) -> torch.Tensor:
"""2D boolean attendance matrix; ``[i, j]`` True ⇒ i attends to j."""
pad = torch.ones(1, PREFIX_LEN, dtype=torch.bool)
return make_att_2d_masks(pad, prefix_att_masks)[0]
def test_mark_sets_att_on_targets_only():
"""Only the supervised target language positions flip to att=1."""
marked = _mark_target_span_causal(
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
)
expected = [False] * PREFIX_LEN
for i in range(LANG_START + N_PROMPT, LANG_END): # target span
expected[i] = True
assert marked[0].tolist() == expected
def test_target_tokens_attend_causally_among_themselves():
"""A target token must NOT attend to later targets, but must attend
to earlier ones genuine causal next-token prediction."""
marked = _mark_target_span_causal(
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
)
attends = _attends(marked)
tgt = range(LANG_START + N_PROMPT, LANG_END)
for i in tgt:
for j in tgt:
if j > i:
assert not attends[i, j], f"target {i} must not see future target {j}"
else:
assert attends[i, j], f"target {i} must see earlier/self target {j}"
def test_target_tokens_attend_prompt_and_images_bidirectionally():
"""Targets keep full visibility of images + the user prompt."""
marked = _mark_target_span_causal(
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
)
attends = _attends(marked)
context = list(range(0, LANG_START + N_PROMPT)) # images + prompt
for i in range(LANG_START + N_PROMPT, LANG_END):
for j in context:
assert attends[i, j], f"target {i} must attend context {j}"
def test_non_target_subtask_stays_bidirectional():
"""A flow-only / non-target language span (all -100 labels) leaves the
mask untouched the action expert reads it bidirectionally."""
all_ignored = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long)
marked = _mark_target_span_causal(
_embed_prefix_att_masks(), all_ignored, LANG_START, LANG_END
)
assert torch.equal(marked, _embed_prefix_att_masks())
def test_unmarked_mask_is_bidirectional_the_bug():
"""Documents the bug the fix prevents: without ``_mark_target_span_causal``
a target token attends *bidirectionally* to later targets the
text-CE can copy the answer it is trained to predict."""
attends = _attends(_embed_prefix_att_masks())
first_tgt = LANG_START + N_PROMPT
last_tgt = LANG_END - 1
assert attends[first_tgt, last_tgt], (
"raw embed_prefix mask is bidirectional over language — the first "
"target token can see the last, which is the collapse bug"
)
def test_shifted_ce_returns_zero_when_no_text_positions_are_supervised():
pytest.importorskip("liger_kernel")
logits = torch.randn(2, 4, 8, requires_grad=True)
labels = torch.full((2, 4), -100, dtype=torch.long)
loss = _shifted_ce(logits, labels)
assert loss.item() == 0
loss.backward()
assert logits.grad is not None
@@ -0,0 +1,114 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Regression tests for PI052 FAST action-code supervision."""
import pytest
import torch
from torch.nn import functional as F
pytest.importorskip("transformers")
pytest.importorskip("liger_kernel")
from lerobot.policies.pi052.modeling_pi052 import _fast_lin_ce # noqa: E402
def _fast_ce(logits, action_tokens, action_code_mask, predict_actions_t):
"""Adapter: ``_fast_lin_ce`` is Liger-fused (hidden @ lm_head_weightᵀ).
Feeding an identity ``lm_head_weight`` makes the computed logits equal the
provided ``logits``, so these regression tests exercise the masking/gating
logic exactly as before the fused-CE refactor. Liger's Triton kernel is
GPU-only, so inputs are moved to CUDA and the loss is returned on CPU
(keeping grad flowing back to the CPU ``logits`` leaf).
"""
if not torch.cuda.is_available():
pytest.skip("Liger fused CE requires CUDA")
vocab_size = logits.shape[-1]
eye = torch.eye(vocab_size, dtype=logits.dtype, device="cuda")
predict = predict_actions_t.cuda() if predict_actions_t is not None else None
loss = _fast_lin_ce(
logits.cuda(), eye, action_tokens.cuda(), action_code_mask.cuda(), predict
)
return loss.cpu()
def test_fast_ce_supervises_only_discrete_action_codes():
"""Wrapper tokens can be wrong without affecting the FAST action-code loss."""
vocab_size = 8
action_tokens = torch.tensor([[1, 2, 3, 4, 5, 0]])
action_code_mask = torch.tensor([[False, False, True, True, False, False]])
logits = torch.zeros(1, action_tokens.shape[1], vocab_size)
# Deliberately bad wrapper-token predictions. These should be ignored.
logits[0, 0, 7] = 10.0 # target would be token 2
logits[0, 3, 7] = 10.0 # target would be delimiter token 5
# Correct action-code predictions: hidden t predicts target t + 1.
logits[0, 1, 3] = 10.0
logits[0, 2, 4] = 10.0
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions_t=None)
expected = F.cross_entropy(
torch.stack([logits[0, 1], logits[0, 2]]),
torch.tensor([3, 4]),
reduction="mean",
)
# Looser tolerance: the fused Triton kernel (GPU) differs from CPU eager
# F.cross_entropy at the ~1e-7 level, which exceeds the default rtol on
# these very small (~1e-4) losses.
assert torch.allclose(loss, expected, atol=1e-5, rtol=1e-3)
def test_fast_ce_masks_non_action_samples():
"""Recipe samples with predict_actions=False do not contribute FAST loss."""
vocab_size = 8
action_tokens = torch.tensor([[1, 2, 3, 4], [1, 2, 5, 6]])
action_code_mask = torch.tensor(
[[False, False, True, True], [False, False, True, True]]
)
predict_actions = torch.tensor([True, False])
logits = torch.zeros(2, action_tokens.shape[1], vocab_size)
logits[0, 1, 3] = 10.0
logits[0, 2, 4] = 10.0
# Bad predictions in the masked sample should not matter.
logits[1, 1, 7] = 10.0
logits[1, 2, 7] = 10.0
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions)
expected = F.cross_entropy(
torch.stack([logits[0, 1], logits[0, 2]]),
torch.tensor([3, 4]),
reduction="mean",
)
# Looser tolerance: the fused Triton kernel (GPU) differs from CPU eager
# F.cross_entropy at the ~1e-7 level, which exceeds the default rtol on
# these very small (~1e-4) losses.
assert torch.allclose(loss, expected, atol=1e-5, rtol=1e-3)
def test_fast_ce_returns_zero_when_no_action_code_positions_are_valid():
logits = torch.randn(2, 4, 8, requires_grad=True)
action_tokens = torch.tensor([[1, 2, 3, 4], [1, 2, 5, 6]])
action_code_mask = torch.zeros_like(action_tokens, dtype=torch.bool)
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions_t=None)
assert loss.item() == 0
loss.backward()
assert logits.grad is not None
@@ -0,0 +1,153 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Numerical-parity tests for the SDPA attention port.
``pi05`` / ``pi052`` replaced the per-layer call from
``modeling_gemma.eager_attention_forward`` with
``sdpa_attention_forward`` (PyTorch SDPA + GQA repeat). The forward
output must be bit-equivalent (within bf16 tolerance) on the masks
this model actually uses block-bidirectional with an arbitrary
additive bias otherwise we silently change training behaviour.
"""
from types import SimpleNamespace
import pytest
import torch
pytest.importorskip("transformers")
from transformers.models.gemma import modeling_gemma # noqa: E402
from lerobot.policies.pi052.modeling_pi052 import make_att_2d_masks # noqa: E402
from lerobot.policies.pi_gemma import sdpa_attention_forward # noqa: E402
from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE # noqa: E402
def _mock_self_attn(num_kv_groups: int, training: bool = False):
"""Bare module surface that both forwards read."""
return SimpleNamespace(
num_key_value_groups=num_kv_groups,
training=training,
)
def _build_inputs(
bsize: int,
num_heads: int,
num_kv_heads: int,
seq_len: int,
head_dim: int,
dtype: torch.dtype,
seed: int = 0,
):
g = torch.Generator(device="cpu").manual_seed(seed)
q = torch.randn(bsize, num_heads, seq_len, head_dim, dtype=dtype, generator=g)
k = torch.randn(bsize, num_kv_heads, seq_len, head_dim, dtype=dtype, generator=g)
v = torch.randn(bsize, num_kv_heads, seq_len, head_dim, dtype=dtype, generator=g)
return q, k, v
def _block_bidirectional_mask(
bsize: int, seq_len: int, block_sizes: list[int], dtype: torch.dtype
) -> torch.Tensor:
"""Mimic ``_prepare_attention_masks_4d`` on a block layout that
matches ``[images, language, suffix]`` from ``embed_prefix`` +
``embed_suffix``: every block bidirectional internally, later
blocks visible to earlier ones via the cumulative-block rule.
"""
assert sum(block_sizes) == seq_len
att_marks = []
for i, n in enumerate(block_sizes):
att_marks += [1 if i > 0 else 0] + [0] * (n - 1)
pad = torch.ones(bsize, seq_len, dtype=torch.bool)
att = torch.tensor(att_marks, dtype=torch.bool)[None].expand(bsize, seq_len)
att_2d = make_att_2d_masks(pad, att)
bias = torch.where(
att_2d[:, None, :, :],
torch.zeros((), dtype=dtype),
torch.tensor(OPENPI_ATTENTION_MASK_VALUE, dtype=dtype),
)
return bias
@pytest.mark.parametrize(
"num_heads,num_kv_heads,head_dim",
[
(8, 1, 256), # gemma_2b / paligemma config
(8, 8, 64), # MHA control (no GQA repeat)
],
)
def test_sdpa_parity_with_eager_block_bidirectional(num_heads, num_kv_heads, head_dim):
"""SDPA forward output matches the eager softmax(QK^T)@V on the
block-bidirectional mask layout pi05 actually uses."""
bsize, seq_len = 2, 13
block_sizes = [4, 5, 4] # images, language, suffix-style blocks
dtype = torch.float32 # cpu math kernel — keep fp32 for tight tol
scaling = head_dim ** -0.5
q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, dtype)
mask = _block_bidirectional_mask(bsize, seq_len, block_sizes, dtype)
module = _mock_self_attn(num_heads // num_kv_heads)
out_eager, _ = modeling_gemma.eager_attention_forward(
module, q, k, v, mask, scaling
)
out_sdpa, _ = sdpa_attention_forward(
module, q, k, v, mask, scaling
)
assert out_eager.shape == out_sdpa.shape
torch.testing.assert_close(out_sdpa, out_eager, atol=1e-5, rtol=1e-4)
def test_sdpa_parity_bf16():
"""bf16 path — looser tolerance, must still match eager."""
bsize, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 1, 17, 256
scaling = head_dim ** -0.5
q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, torch.bfloat16)
mask = _block_bidirectional_mask(bsize, seq_len, [5, 6, 6], torch.bfloat16)
module = _mock_self_attn(num_heads // num_kv_heads)
out_eager, _ = modeling_gemma.eager_attention_forward(
module, q, k, v, mask, scaling
)
out_sdpa, _ = sdpa_attention_forward(
module, q, k, v, mask, scaling
)
torch.testing.assert_close(out_sdpa, out_eager, atol=2e-2, rtol=2e-2)
def test_sdpa_parity_backward():
"""Gradients flow through SDPA and match the eager path within
bf16 tolerance critical for any training-side parity claim."""
bsize, num_heads, num_kv_heads, seq_len, head_dim = 1, 4, 2, 9, 32
scaling = head_dim ** -0.5
q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, torch.float32)
q.requires_grad_(True); k.requires_grad_(True); v.requires_grad_(True)
mask = _block_bidirectional_mask(bsize, seq_len, [3, 3, 3], torch.float32)
module = _mock_self_attn(num_heads // num_kv_heads)
out_e, _ = modeling_gemma.eager_attention_forward(module, q, k, v, mask, scaling)
g_q_e, g_k_e, g_v_e = torch.autograd.grad(out_e.sum(), [q, k, v])
out_s, _ = sdpa_attention_forward(module, q, k, v, mask, scaling)
g_q_s, g_k_s, g_v_s = torch.autograd.grad(out_s.sum(), [q, k, v])
torch.testing.assert_close(g_q_s, g_q_e, atol=1e-5, rtol=1e-4)
torch.testing.assert_close(g_k_s, g_k_e, atol=1e-5, rtol=1e-4)
torch.testing.assert_close(g_v_s, g_v_e, atol=1e-5, rtol=1e-4)
@@ -0,0 +1,196 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for PI052's text tokenizer.
Covers ``say`` tool-call flattening (PaliGemma's flat prompt has no
structured tool calls, so a ``say`` call must be serialized into a
``<say>...</say>`` text marker) and EOS-termination supervision (the
supervised target span must end with an EOS token so the LM head learns
to stop instead of rambling to ``max_length`` at inference).
"""
import torch
from lerobot.policies.pi052.text_processor_pi052 import (
PI052TextTokenizerStep,
_flatten_say_tool_calls,
_format_messages,
)
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
def _say_call(text):
return {"type": "function", "function": {"name": "say", "arguments": {"text": text}}}
def test_flatten_appends_say_marker_and_drops_tool_calls():
msg = {"role": "assistant", "content": "Heading to the cube.", "tool_calls": [_say_call("On it!")]}
out = _flatten_say_tool_calls(msg)
assert "tool_calls" not in out
assert out["content"] == "Heading to the cube.\n<say>On it!</say>"
def test_flatten_marker_only_when_content_empty_or_none():
out = _flatten_say_tool_calls({"role": "assistant", "tool_calls": [_say_call("hi")]})
assert out["content"] == "<say>hi</say>"
def test_flatten_accepts_json_string_arguments():
call = {"type": "function", "function": {"name": "say", "arguments": '{"text": "hello there"}'}}
out = _flatten_say_tool_calls({"role": "assistant", "content": "p", "tool_calls": [call]})
assert out["content"] == "p\n<say>hello there</say>"
def test_flatten_leaves_messages_without_tool_calls_untouched():
msg = {"role": "assistant", "content": "just a plan"}
assert _flatten_say_tool_calls(msg) == msg
def test_flatten_drops_non_say_tool_calls_but_keeps_content():
weather = {"type": "function", "function": {"name": "check_weather", "arguments": {}}}
out = _flatten_say_tool_calls(
{"role": "assistant", "content": "plan only", "tool_calls": [weather]}
)
assert out["content"] == "plan only"
assert "tool_calls" not in out
# ---------------------------------------------------------------------------
# EOS-termination supervision
# ---------------------------------------------------------------------------
def test_format_messages_appends_eos_to_target_turns_only():
msgs = [
{"role": "user", "content": "pick cube"},
{"role": "assistant", "content": "move to cube"},
]
prompt, spans = _format_messages(msgs, target_indices=[1], eos_token="<eos>")
# EOS is appended to the supervised target (assistant) turn only.
assert prompt == "User: pick cube\nAssistant: move to cube<eos>\n"
# The user span is unchanged; the target span covers content + EOS.
assert prompt[spans[0][0] : spans[0][1]] == "pick cube"
assert prompt[spans[1][0] : spans[1][1]] == "move to cube<eos>"
def test_format_messages_without_eos_args_is_unchanged():
"""Inference callers omit target_indices / eos_token — no EOS baked in."""
prompt, spans = _format_messages([{"role": "user", "content": "hi"}])
assert prompt == "User: hi\n"
assert prompt[spans[0][0] : spans[0][1]] == "hi"
def _eos_char_id() -> int:
"""Token id _CharTokenizer assigns to its 1-char EOS."""
return ord("\x1f") % 251 + 1
def test_pi052_text_tokenizer_supervises_eos_at_target_end():
"""The appended EOS is the last supervised label on a target turn —
that's the signal that teaches the LM head to stop. The trailing
newline right after it stays unsupervised (-100)."""
step = PI052TextTokenizerStep(max_length=64)
step._tokenizer = _CharTokenizer()
transition = {
TransitionKey.OBSERVATION: {},
TransitionKey.COMPLEMENTARY_DATA: {
"messages": [
{"role": "user", "content": "pick cube"},
{"role": "assistant", "content": "move to cube"},
],
"target_message_indices": [1],
"message_streams": ["high_level", "high_level"],
"index": torch.tensor(10),
},
}
out = step(transition)
ids = out[TransitionKey.OBSERVATION][OBS_LANGUAGE_TOKENS][0]
labels = out[TransitionKey.COMPLEMENTARY_DATA]["text_labels"][0]
supervised = (labels != -100).nonzero().flatten().tolist()
assert supervised, "target turn produced no supervised labels"
last = supervised[-1]
# The last supervised token is the appended EOS.
assert int(ids[last]) == _eos_char_id()
assert int(labels[last]) == _eos_char_id()
# The token right after the EOS (the trailing newline) is NOT supervised.
assert int(labels[last + 1]) == -100
class _CharTokenizer:
pad_token_id = 0
eos_token = "\x1f" # unit separator — a 1-char "EOS" for testing
def __call__(
self,
text,
max_length,
padding,
truncation,
return_tensors,
return_offsets_mapping,
padding_side,
):
ids = [ord(c) % 251 + 1 for c in text[:max_length]]
offsets = [(i, i + 1) for i in range(len(ids))]
attention = [1] * len(ids)
if padding == "max_length" and len(ids) < max_length:
pad = max_length - len(ids)
ids += [self.pad_token_id] * pad
offsets += [(0, 0)] * pad
attention += [0] * pad
return {
"input_ids": torch.tensor([ids], dtype=torch.long),
"attention_mask": torch.tensor([attention], dtype=torch.long),
"offset_mapping": torch.tensor([offsets], dtype=torch.long),
}
def decode(self, token_ids, skip_special_tokens=False):
return "".join(chr(max(int(i) - 1, 0)) for i in token_ids if int(i) != self.pad_token_id)
def test_pi052_text_tokenizer_handles_batched_rendered_messages():
step = PI052TextTokenizerStep(max_length=64)
step._tokenizer = _CharTokenizer()
transition = {
TransitionKey.OBSERVATION: {},
TransitionKey.COMPLEMENTARY_DATA: {
"messages": [
[
{"role": "user", "content": "pick cube"},
{"role": "assistant", "content": "move to cube"},
],
[{"role": "user", "content": "open drawer"}],
],
"target_message_indices": [[1], []],
"message_streams": [["high_level", "high_level"], ["low_level"]],
"index": torch.tensor([10, 11]),
},
}
out = step(transition)
obs = out[TransitionKey.OBSERVATION]
comp = out[TransitionKey.COMPLEMENTARY_DATA]
assert obs[OBS_LANGUAGE_TOKENS].shape == (2, 64)
assert obs[OBS_LANGUAGE_ATTENTION_MASK].shape == (2, 64)
assert comp["text_labels"].shape == (2, 64)
assert comp["predict_actions"].tolist() == [False, True]
assert (comp["text_labels"][0] != -100).any()
assert not (comp["text_labels"][1] != -100).any()
+187
View File
@@ -0,0 +1,187 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training-side conversion of VQA answers to PaliGemma ``<loc>`` text.
PI052 trains spatial VQA answers (``bbox`` / ``keypoint``) in
PaliGemma's native ``<locNNNN>`` detection vocabulary so the LM head
reuses the detection prior instead of fighting it (the ``<loc>``-salad
bug). The dataset stores Qwen2.5-VL's grounding output — **01000
normalized** coordinates, *not* pixels. (Verified empirically on the
published datasets: x and y both span 0..1000 with ~30% of values
exceeding the camera's pixel dimensions.) The conversion is therefore
camera-resolution-independent. The dataset stays backbone-agnostic
JSON; the conversion lives in PI052's tokenizer. These tests pin the
JSON ``<loc>`` rewrite.
"""
import pytest
pytest.importorskip("transformers")
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: E402
_loc_token,
_messages_vqa_to_loc,
_vqa_answer_to_loc,
register_paligemma_loc_tokens,
)
class _FakeTokenizer:
"""Tracks ``add_tokens`` calls; mimics the bits ``register_paligemma_loc_tokens`` reads."""
def __init__(self, prepopulate: bool = False):
self.added_tokens_encoder: dict[str, int] = {}
self.calls: list[list[str]] = []
if prepopulate:
self.added_tokens_encoder["<loc0000>"] = 256000
def add_tokens(self, tokens: list[str]) -> int:
self.calls.append(list(tokens))
for t in tokens:
self.added_tokens_encoder.setdefault(t, len(self.added_tokens_encoder) + 256000)
return len(tokens)
def test_register_loc_tokens_adds_full_1024_range():
tok = _FakeTokenizer()
out = register_paligemma_loc_tokens(tok)
assert out is tok # returns same instance
assert len(tok.calls) == 1
added = tok.calls[0]
assert len(added) == 1024
assert added[0] == "<loc0000>"
assert added[-1] == "<loc1023>"
# Spot check a few in the middle.
assert added[162] == "<loc0162>"
assert added[759] == "<loc0759>"
def test_register_loc_tokens_is_idempotent():
"""If the loc tokens are already present we skip re-adding them."""
tok = _FakeTokenizer(prepopulate=True)
register_paligemma_loc_tokens(tok)
register_paligemma_loc_tokens(tok)
assert tok.calls == [] # never called add_tokens
def test_loc_token_normalizes_and_clamps():
# Default scale is the 01000 Qwen convention.
assert _loc_token(0) == "<loc0000>"
assert _loc_token(1000) == "<loc1023>"
assert _loc_token(500) == f"<loc{round(500 / 1000 * 1023):04d}>"
# out-of-range coordinates clamp into [0, 1023]
assert _loc_token(9999) == "<loc1023>"
assert _loc_token(-5) == "<loc0000>"
def test_vqa_answer_to_loc_keypoint_normalized():
# Label-first: avoids the "Assistant: → <loc>" attractor at training.
answer = {"label": "blue cube", "point_format": "xy", "point": [500, 500]}
assert _vqa_answer_to_loc(answer) == "blue cube <loc0512><loc0512>"
def test_vqa_answer_to_loc_bbox_normalized():
answer = {
"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [0, 0, 1000, 1000]}]
}
assert _vqa_answer_to_loc(answer) == "cube <loc0000><loc0000><loc1023><loc1023>"
def test_vqa_answer_to_loc_multiple_detections_separator():
answer = {
"detections": [
{"label": "blue", "bbox_format": "xyxy", "bbox": [0, 0, 500, 500]},
{"label": "yellow", "bbox_format": "xyxy", "bbox": [500, 500, 1000, 1000]},
]
}
out = _vqa_answer_to_loc(answer)
# Each segment is "label <locs>", joined by " ; "
assert out == (
"blue <loc0000><loc0000><loc0512><loc0512> ; "
"yellow <loc0512><loc0512><loc1023><loc1023>"
)
def test_vqa_answer_to_loc_returns_none_for_non_spatial():
assert _vqa_answer_to_loc({"label": "cubes", "count": 2}) is None
assert _vqa_answer_to_loc({"weird": "payload"}) is None
def test_messages_vqa_to_loc_rewrites_target_turn():
messages = [
{"role": "user", "content": [{"type": "text", "text": "where is the cube?"}]},
{
"role": "assistant",
"content": '{"label": "cube", "point_format": "xy", "point": [500, 500]}',
},
]
out = _messages_vqa_to_loc(messages, target_indices=[1])
assert out[1]["content"] == "cube <loc0512><loc0512>"
# input messages are not mutated
assert messages[1]["content"].startswith("{")
def test_messages_vqa_to_loc_leaves_plain_text_targets_untouched():
messages = [
{"role": "user", "content": "pick the cube"},
{"role": "assistant", "content": "pick up the cube"},
]
out = _messages_vqa_to_loc(messages, target_indices=[1])
assert out[1]["content"] == "pick up the cube"
def test_messages_vqa_to_loc_noop_without_target_indices():
messages = [
{"role": "assistant", "content": '{"label": "c", "point_format": "xy", "point": [1, 2]}'}
]
assert _messages_vqa_to_loc(messages, []) is messages
# ---------------------------------------------------------------------------
# Round-trip: training-side JSON -> <loc> -> runtime-side parse back
#
# Pins that the conversion preserves coordinate *order* (JSON is x-first,
# PaliGemma <loc> is y-first) and the 01000 → [0, 1023] scaling. The
# only loss is quantization to the 1024-bucket <loc> grid, so a coord
# survives within half a bucket (~1000/2046 ≈ 0.49 on the 01000 scale).
# ---------------------------------------------------------------------------
def test_loc_round_trip_keypoint_preserves_normalized_coords():
from lerobot.policies.pi052.inference.vqa import parse_vqa_answer
answer = {"label": "blue cube", "point_format": "xy", "point": [640, 480]}
loc = _vqa_answer_to_loc(answer)
parsed = parse_vqa_answer(loc)
nx, ny = parsed["payload"]["point"]
# parse_vqa_answer returns [0, 1] normalized; rescale back to 01000.
assert abs(nx * 1000.0 - 640) <= 1000.0 / 2046 + 1e-6
assert abs(ny * 1000.0 - 480) <= 1000.0 / 2046 + 1e-6
assert parsed["payload"]["label"] == "blue cube"
def test_loc_round_trip_bbox_preserves_order_and_scale():
from lerobot.policies.pi052.inference.vqa import parse_vqa_answer
answer = {
"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [100, 200, 800, 900]}]
}
loc = _vqa_answer_to_loc(answer)
parsed = parse_vqa_answer(loc)
x1, y1, x2, y2 = parsed["payload"]["detections"][0]["bbox"]
for got, want in ((x1, 100), (y1, 200), (x2, 800), (y2, 900)):
assert abs(got * 1000.0 - want) <= 1000.0 / 2046 + 1e-6
-220
View File
@@ -24,7 +24,6 @@ from typing import Any
import pytest
import torch
import torch.nn as nn
from safetensors.torch import load_file
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
@@ -175,53 +174,6 @@ class MockStepWithTensorState(ProcessorStep):
return features
class MockLazyTensorStateStep(ProcessorStep):
"""Mock step whose tensor state is not present in constructor config."""
def __init__(
self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None
):
self.name = name
self.scale = scale
self.tensor_state: torch.Tensor | None = None
if initial_value is not None:
self.tensor_state = torch.tensor([initial_value], dtype=torch.float32)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Return the transition unchanged."""
return transition
def get_config(self) -> dict[str, Any]:
"""Return constructor config while intentionally omitting tensor state."""
return {
"name": self.name,
"scale": self.scale,
}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return tensor state only after it has been initialized or loaded."""
if self.tensor_state is None:
return {}
return {"tensor_state": self.tensor_state}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load tensor state."""
self.tensor_state = state["tensor_state"].clone()
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Return features unchanged."""
return features
@ProcessorStepRegistry.register("registered_lazy_tensor_state_step")
class RegisteredLazyTensorStateStep(MockLazyTensorStateStep):
"""Registered lazy tensor state step for registry-based serialization tests."""
def test_empty_pipeline():
"""Test pipeline with no steps."""
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
@@ -668,178 +620,6 @@ def test_mixed_json_and_tensor_state():
assert torch.allclose(loaded_step.running_mean, step.running_mean)
def test_get_config_matches_saved_json():
"""Test that in-memory config matches the config written by save_pretrained."""
stateless_step = MockStep(name="stateless")
stateful_step = MockLazyTensorStateStep(name="stateful", initial_value=4.0)
pipeline = DataProcessorPipeline([stateless_step, stateful_step], name="Memory Pipeline")
in_memory_config = pipeline.get_config()
assert pipeline.get_config() == in_memory_config
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
config_path = Path(tmp_dir) / "memory_pipeline.json"
with open(config_path) as file_pointer:
saved_config = json.load(file_pointer)
assert in_memory_config == saved_config
assert "state_file" not in in_memory_config["steps"][0]
assert in_memory_config["steps"][1]["state_file"] == "memory_pipeline_step_1.safetensors"
def test_state_dict_matches_saved_safetensors():
"""Test that in-memory state matches the safetensors written by save_pretrained."""
stateful_step = MockLazyTensorStateStep(initial_value=7.0)
pipeline = DataProcessorPipeline([stateful_step], name="Stateful Pipeline")
in_memory_state_dict = pipeline.state_dict()
state_filename = "stateful_pipeline_step_0.safetensors"
state_key = "stateful_pipeline_step_0"
assert set(in_memory_state_dict) == {state_key}
assert set(in_memory_state_dict[state_key]) == {"tensor_state"}
in_memory_state_dict[state_key]["tensor_state"].add_(1)
assert stateful_step.tensor_state is not None
assert torch.equal(stateful_step.tensor_state, torch.tensor([7.0]))
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
saved_state_dict = load_file(Path(tmp_dir) / state_filename)
torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0]))
def test_save_pretrained_still_writes_expected_serialization_files():
"""Test that save_pretrained keeps the existing config and state filenames."""
stateful_step = MockLazyTensorStateStep(initial_value=3.0)
pipeline = DataProcessorPipeline([stateful_step], name="Policy Preprocessor")
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
save_path = Path(tmp_dir)
assert (save_path / "policy_preprocessor.json").exists()
assert (save_path / "policy_preprocessor_step_0.safetensors").exists()
def test_from_config_round_trips_stateful_pipeline():
"""Test that from_config rebuilds a stateful pipeline from in-memory artifacts."""
stateful_step = MockLazyTensorStateStep(name="roundtrip", initial_value=11.0)
pipeline = DataProcessorPipeline([stateful_step], name="Roundtrip Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
loaded_step = loaded_pipeline.steps[0]
assert len(loaded_pipeline) == 1
assert isinstance(loaded_step, MockLazyTensorStateStep)
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([11.0]))
def test_from_config_round_trips_registered_stateful_pipeline():
"""Test that from_config resolves registry steps and loads their named tensor state."""
stateful_step = RegisteredLazyTensorStateStep(name="registered", initial_value=29.0)
pipeline = DataProcessorPipeline([stateful_step], name="Registry Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors"
state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step"
assert config["steps"][0]["registry_name"] == "registered_lazy_tensor_state_step"
assert config["steps"][0]["state_file"] == state_filename
assert set(pipeline_state_dict) == {state_key}
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
loaded_step = loaded_pipeline.steps[0]
assert isinstance(loaded_step, RegisteredLazyTensorStateStep)
assert loaded_step.tensor_state is not None
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([29.0]))
def test_from_config_preserves_state_metadata_for_empty_initial_state():
"""Test in-memory loading when rebuilt steps start without tensor state."""
stateful_step = MockLazyTensorStateStep(name="lazy", initial_value=13.0)
pipeline = DataProcessorPipeline([stateful_step], name="Lazy Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
loaded_pipeline = DataProcessorPipeline.from_config(config)
loaded_step = loaded_pipeline.steps[0]
assert isinstance(loaded_step, MockLazyTensorStateStep)
assert loaded_step.state_dict() == {}
assert "state_file" not in loaded_pipeline.get_config()["steps"][0]
loaded_pipeline.load_state_dict(pipeline_state_dict)
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([13.0]))
def test_from_config_applies_overrides_before_state_loading():
"""Test that constructor overrides and tensor state loading are separate operations."""
stateful_step = MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0)
pipeline = DataProcessorPipeline([stateful_step], name="Override Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
loaded_pipeline = DataProcessorPipeline.from_config(
config,
state_dict=pipeline_state_dict,
overrides={"MockLazyTensorStateStep": {"scale": 5.0}},
)
loaded_step = loaded_pipeline.steps[0]
assert isinstance(loaded_step, MockLazyTensorStateStep)
assert loaded_step.scale == 5.0
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([17.0]))
def test_load_state_dict_raises_on_missing_expected_state():
"""Test loading raises when serialized config expects missing state."""
stateful_step = MockLazyTensorStateStep(initial_value=19.0)
pipeline = DataProcessorPipeline([stateful_step], name="Missing Pipeline")
loaded_pipeline = DataProcessorPipeline.from_config(pipeline.get_config())
with pytest.raises(KeyError, match="missing_pipeline_step_0"):
loaded_pipeline.load_state_dict({})
def test_load_state_dict_raises_on_unexpected_extra_state():
"""Test loading raises on unexpected top-level state keys."""
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline")
with pytest.raises(KeyError, match="extra"):
pipeline.load_state_dict({"extra": {"tensor_state": torch.tensor([1.0])}})
def test_stateless_pipeline_in_memory_serialization_returns_empty_state():
"""Test stateless in-memory serialization and loading."""
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline")
config = pipeline.get_config()
config_without_name = {"steps": config["steps"]}
assert pipeline.state_dict() == {}
assert all("state_file" not in step_entry for step_entry in config["steps"])
loaded_pipeline = DataProcessorPipeline.from_config(config_without_name, state_dict={})
assert loaded_pipeline.name == "DataProcessorPipeline"
assert loaded_pipeline.state_dict() == {}
@pytest.mark.parametrize("invalid_config", [None, [], "not config"])
def test_from_config_rejects_non_dict_config(invalid_config):
"""Test from_config reports invalid top-level config values cleanly."""
with pytest.raises(ValueError, match="not a valid processor configuration"):
DataProcessorPipeline.from_config(invalid_config) # type: ignore[arg-type]
class MockModuleStep(ProcessorStep, nn.Module):
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
@@ -58,3 +58,70 @@ def test_render_messages_step_renders_and_drops_raw_language():
assert data["messages"][-1]["content"] == "reach carefully"
assert data["message_streams"] == ["high_level", "low_level"]
assert data["target_message_indices"] == [1]
def test_render_messages_step_falls_back_to_low_level_task_when_recipe_misses():
recipe = TrainingRecipe(
messages=[
MessageTurn(
role="assistant",
content="${subtask}",
stream="high_level",
target=True,
if_present="subtask",
),
]
)
transition = create_transition(
complementary_data={
"task": "pick the cube",
"timestamp": torch.tensor(0.0),
"index": torch.tensor(7),
"language_persistent": [],
"language_events": [{"style": "unmatched", "timestamp": 0.0}],
}
)
out = RenderMessagesStep(recipe)(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert data["messages"] == [{"role": "user", "content": "pick the cube"}]
assert data["message_streams"] == ["low_level"]
assert data["target_message_indices"] == []
def test_render_messages_step_falls_back_per_sample_in_batched_language():
recipe = TrainingRecipe(
messages=[
MessageTurn(
role="assistant",
content="${subtask}",
stream="high_level",
target=True,
if_present="subtask",
),
]
)
transition = create_transition(
action=torch.arange(4).reshape(2, 2),
complementary_data={
"task": ["pick the cube", "open the drawer"],
"timestamp": torch.tensor([0.0, 1.0]),
"index": torch.tensor([7, 8]),
"language_persistent": [[], []],
"language_events": [
[{"style": "unmatched", "timestamp": 0.0}],
[{"style": "unmatched", "timestamp": 1.0}],
],
},
)
out = RenderMessagesStep(recipe)(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert data["messages"] == [
[{"role": "user", "content": "pick the cube"}],
[{"role": "user", "content": "open the drawer"}],
]
assert data["message_streams"] == [["low_level"], ["low_level"]]
assert data["target_message_indices"] == [[], []]
@@ -66,20 +66,6 @@ class TestOperationTypeParsing:
with pytest.raises(ValueError, match="--new_repo_id is required for merge"):
_validate_config(cfg)
@pytest.mark.parametrize("flag", ["concatenate_videos", "concatenate_data"])
def test_merge_concatenate_flag_defaults_true(self, flag):
cfg = parse_cfg(["--new_repo_id", "test/merged", "--operation.type", "merge"])
assert isinstance(cfg.operation, MergeConfig)
assert getattr(cfg.operation, flag) is True
@pytest.mark.parametrize("flag", ["concatenate_videos", "concatenate_data"])
def test_merge_concatenate_flag_can_be_disabled(self, flag):
cfg = parse_cfg(
["--new_repo_id", "test/merged", "--operation.type", "merge", f"--operation.{flag}", "false"]
)
assert isinstance(cfg.operation, MergeConfig)
assert getattr(cfg.operation, flag) is False
def test_non_merge_requires_repo_id(self):
cfg = parse_cfg(["--operation.type", "delete_episodes"])
with pytest.raises(ValueError, match="--repo_id is required for delete_episodes"):
+1 -29
View File
@@ -1,19 +1,5 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from types import SimpleNamespace
@@ -42,14 +28,6 @@ def test_push_to_hub_tags_uploaded_dataset_revision(tmp_path, monkeypatch):
calls["upload_folder"] = kwargs
return SimpleNamespace(oid="abc123")
def delete_tag(self, repo_id, **kwargs):
import requests
from huggingface_hub.errors import RevisionNotFoundError
calls["delete_tag"] = {"repo_id": repo_id, **kwargs}
# Simulate the common case: no stale tag to delete.
raise RevisionNotFoundError("no such tag", response=requests.Response())
def create_tag(self, **kwargs):
calls["create_tag"] = kwargs
@@ -71,16 +49,10 @@ def test_push_to_hub_tags_uploaded_dataset_revision(tmp_path, monkeypatch):
"exist_ok": True,
}
assert calls["upload_folder"]["repo_id"] == "annotated/dataset"
# A stale tag (e.g. from a previous annotation run) is deleted first so
# the new tag always points at the upload we just made.
assert calls["delete_tag"] == {
"repo_id": "annotated/dataset",
"tag": "v3.0",
"repo_type": "dataset",
}
assert calls["create_tag"] == {
"repo_id": "annotated/dataset",
"tag": "v3.0",
"repo_type": "dataset",
"exist_ok": True,
"revision": "abc123",
}
+2 -2
View File
@@ -134,7 +134,7 @@ class TestMultiGPUTraining:
f"--output_dir={output_dir}",
"--batch_size=4",
"--steps=10",
"--env_eval_freq=-1",
"--eval_freq=-1",
"--log_freq=5",
"--save_freq=10",
"--seed=42",
@@ -177,7 +177,7 @@ class TestMultiGPUTraining:
f"--output_dir={output_dir}",
"--batch_size=4",
"--steps=20",
"--env_eval_freq=-1",
"--eval_freq=-1",
"--log_freq=5",
"--save_freq=10",
"--seed=42",
+1 -77
View File
@@ -15,7 +15,6 @@
# limitations under the License.
import pytest
import torch
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
@@ -26,16 +25,8 @@ def mock_metrics():
class MockAccelerator:
def __init__(self, num_processes: int, reduce_fn=None):
def __init__(self, num_processes: int):
self.num_processes = num_processes
self.device = torch.device("cpu")
self._reduce_fn = reduce_fn
def reduce(self, tensor, reduction="mean"):
# In single-process tests we just want a deterministic stand-in for accelerate's reduce.
if self._reduce_fn is not None:
return self._reduce_fn(tensor, reduction)
return tensor
def test_average_meter_initialization():
@@ -166,70 +157,3 @@ def test_metrics_tracker_reset_averages(mock_metrics):
tracker.reset_averages()
assert tracker.loss.avg == 0.0
assert tracker.accuracy.avg == 0.0
def test_average_meter_invalid_reduction():
with pytest.raises(ValueError):
AverageMeter("loss", reduction="median")
def test_average_meter_reduction_stored():
meter = AverageMeter("updt_s", reduction="max")
assert meter.reduction == "max"
def test_metrics_tracker_reduce_across_ranks_no_accelerator():
metrics = {"update_s": AverageMeter("update_s", reduction="max")}
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=metrics)
tracker.update_s = 0.5
tracker.reduce_across_ranks() # no-op without accelerator
assert tracker.update_s.avg == 0.5
def test_metrics_tracker_reduce_across_ranks_single_process():
metrics = {"update_s": AverageMeter("update_s", reduction="max")}
tracker = MetricsTracker(
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=metrics,
accelerator=MockAccelerator(num_processes=1),
)
tracker.update_s = 0.5
tracker.reduce_across_ranks() # no-op when world size is 1
assert tracker.update_s.avg == 0.5
def test_metrics_tracker_reduce_across_ranks_invokes_reduce():
captured = {}
def fake_reduce(tensor, reduction):
captured["reduction"] = reduction
captured["values"] = tensor.clone()
# Pretend the slowest rank reported 0.9 instead of this rank's 0.4.
return torch.tensor([0.9], dtype=tensor.dtype, device=tensor.device)
metrics = {
"loss": AverageMeter("loss"), # reduction="none" -> not touched
"update_s": AverageMeter("update_s", reduction="max"),
}
tracker = MetricsTracker(
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=metrics,
accelerator=MockAccelerator(num_processes=4, reduce_fn=fake_reduce),
)
tracker.loss = 1.0
tracker.update_s = 0.4
tracker.reduce_across_ranks()
assert captured["reduction"] == "max"
assert torch.allclose(captured["values"], torch.tensor([0.4]))
assert tracker.update_s.avg == pytest.approx(0.9)
# Metrics without a reduction stay untouched.
assert tracker.loss.avg == 1.0
# Invariant: avg == sum / count must hold after reduce, so subsequent .update() calls
# accumulate against the cluster view rather than the stale per-rank sum.
meter = tracker.update_s
assert meter.sum / meter.count == pytest.approx(meter.avg)
-24
View File
@@ -20,8 +20,6 @@ from unittest.mock import Mock, patch
from lerobot.common.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_batch_size,
load_training_num_processes,
load_training_state,
load_training_step,
save_checkpoint,
@@ -65,28 +63,6 @@ def test_load_training_step(tmp_path):
assert loaded_step == step
def test_save_training_state_records_num_processes(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler, num_processes=4)
assert load_training_num_processes(tmp_path) == 4
def test_load_training_num_processes_absent_returns_none(tmp_path, optimizer, scheduler):
# Checkpoints written before the world size was recorded must still load (back-compat).
save_training_state(tmp_path, 10, optimizer, scheduler)
assert load_training_num_processes(tmp_path) is None
def test_save_training_state_records_batch_size(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler, batch_size=32)
assert load_training_batch_size(tmp_path) == 32
def test_load_training_batch_size_absent_returns_none(tmp_path, optimizer, scheduler):
# Checkpoints written before the batch size was recorded must still load (back-compat).
save_training_state(tmp_path, 10, optimizer, scheduler)
assert load_training_batch_size(tmp_path) is None
def test_update_last_checkpoint(tmp_path):
checkpoint = tmp_path / "0005"
checkpoint.mkdir()

Some files were not shown because too many files have changed in this diff Show More