mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
Compare commits
6 Commits
a47e535b02
...
a764c3e1d6
| Author | SHA1 | Date | |
|---|---|---|---|
| a764c3e1d6 | |||
| b416f287f2 | |||
| aa749d4947 | |||
| 1394a6ab5d | |||
| db9118f16f | |||
| 7a945d7bdc |
@@ -23,6 +23,18 @@ token = os.environ.get("HF_TOKEN") or get_token()
|
||||
if not token:
|
||||
raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
|
||||
|
||||
# --- Diversity knobs (Pi0.7-style prompt expansion) -----------------------
|
||||
# Bumped roughly 3x across the board to fight memorization on small datasets.
|
||||
# A single dataset trained for many epochs with deterministic atom wording
|
||||
# converges to perfect recall on training prompts but produces JSON-token
|
||||
# garbage at inference for any wording that drifts slightly. More atom
|
||||
# variants per episode + higher sampling temperature widens the training
|
||||
# distribution so the model has to actually use its language head, not
|
||||
# just memorize.
|
||||
#
|
||||
# Pushes to a *new* hub repo (``_tool3``) so the previous annotation pass
|
||||
# (``_tool2``) stays intact — re-train from scratch on the new dataset and
|
||||
# compare loss-curve shapes to verify the diversity bump is doing something.
|
||||
CMD = (
|
||||
"apt-get update -qq && apt-get install -y -qq git ffmpeg && "
|
||||
"pip install --no-deps "
|
||||
@@ -41,19 +53,21 @@ CMD = (
|
||||
"--tensor-parallel-size 1 --max-model-len 32768 "
|
||||
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
|
||||
"--vlm.serve_ready_timeout_s=1800 "
|
||||
"--vlm.client_concurrency=256 "
|
||||
"--vlm.client_concurrency=128 "
|
||||
"--vlm.max_new_tokens=512 "
|
||||
"--executor.episode_parallelism=32 "
|
||||
"--vlm.temperature=0.7 "
|
||||
"--executor.episode_parallelism=16 "
|
||||
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}' "
|
||||
"--vlm.camera_key=observation.images.wrist "
|
||||
"--module_1.frames_per_second=1.0 "
|
||||
"--module_1.use_video_url=true "
|
||||
"--module_1.use_video_url_fps=1.0 "
|
||||
"--module_1.derive_task_from_video=always "
|
||||
"--module_1.n_task_rephrasings=10 "
|
||||
"--module_3.K=1 "
|
||||
"--module_1.n_task_rephrasings=30 "
|
||||
"--module_2.max_interjections_per_episode=6 "
|
||||
"--module_3.K=3 "
|
||||
"--module_3.vqa_emission_hz=1.0 "
|
||||
"--push_to_hub=pepijn223/super_poulain_full_tool2"
|
||||
"--push_to_hub=pepijn223/super_poulain_full_tool3"
|
||||
)
|
||||
|
||||
job = run_job(
|
||||
|
||||
@@ -237,17 +237,24 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
||||
hub_versions = get_repo_versions(repo_id)
|
||||
|
||||
if not hub_versions:
|
||||
raise RevisionNotFoundError(
|
||||
f"""Your dataset must be tagged with a codebase version.
|
||||
Assuming _version_ is the codebase_version value in the info.json, you can run this:
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
|
||||
```
|
||||
"""
|
||||
msg = (
|
||||
f"Repo {repo_id!r} has no codebase-version tags. The dataset "
|
||||
f"either doesn't exist on the Hub yet, or it was uploaded "
|
||||
f"without a ``v3.x``-style tag. To tag an existing dataset run:\n"
|
||||
f" from huggingface_hub import HfApi\n"
|
||||
f" HfApi().create_tag({repo_id!r}, tag='v3.0', repo_type='dataset', exist_ok=True)"
|
||||
)
|
||||
# ``RevisionNotFoundError`` extends ``HfHubHTTPError`` whose
|
||||
# ``__init__`` indexes ``response.headers`` unconditionally on
|
||||
# current ``huggingface_hub`` versions. Constructing it without
|
||||
# a real ``Response`` object crashes with either
|
||||
# ``TypeError: missing 1 required keyword-only argument`` (old
|
||||
# builds) or ``AttributeError: 'NoneType' object has no attribute
|
||||
# 'headers'`` (new builds). Skip that path entirely — this isn't
|
||||
# really an HTTP error, it's a configuration issue — and raise a
|
||||
# plain ``RuntimeError`` so the message actually reaches the
|
||||
# caller.
|
||||
raise RuntimeError(msg)
|
||||
|
||||
if target_version in hub_versions:
|
||||
return f"v{target_version}"
|
||||
|
||||
@@ -271,6 +271,9 @@ class HighLevelSubtaskFwd(InferenceStep):
|
||||
msg = _generate_with_policy(
|
||||
self.policy, ctx, observation=observation, state=state, label="subtask gen"
|
||||
)
|
||||
if msg and _looks_like_gibberish(msg):
|
||||
push_log(state, f" [info] subtask gen rejected (gibberish): {msg[:60]!r}")
|
||||
return None
|
||||
if msg:
|
||||
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
|
||||
if changed:
|
||||
@@ -307,6 +310,9 @@ class MemoryUpdateFwd(InferenceStep):
|
||||
new_memory = _generate_with_policy(
|
||||
self.policy, ctx, observation=observation, state=state, label="memory gen"
|
||||
)
|
||||
if new_memory and _looks_like_gibberish(new_memory):
|
||||
push_log(state, f" [info] memory gen rejected (gibberish): {new_memory[:60]!r}")
|
||||
return None
|
||||
if new_memory:
|
||||
set_if_changed(state, "current_memory", new_memory, label="memory")
|
||||
return None
|
||||
@@ -340,11 +346,16 @@ class UserInterjectionFwd(InferenceStep):
|
||||
if not out:
|
||||
push_log(state, " [info] plan/say gen produced no text this tick")
|
||||
return None
|
||||
if _looks_like_gibberish(out):
|
||||
push_log(state, f" [info] plan/say gen rejected (gibberish): {out[:60]!r}")
|
||||
return None
|
||||
# Heuristic split: model is trained to emit one assistant turn
|
||||
# carrying both plan text AND a `say` tool call. Look for a
|
||||
# "<say>...</say>" or "say(...)" marker; fall back to whole
|
||||
# text → plan, no speech.
|
||||
plan_text, speech_text = _split_plan_and_say(out)
|
||||
if plan_text and _looks_like_gibberish(plan_text):
|
||||
plan_text = ""
|
||||
if plan_text:
|
||||
set_if_changed(state, "current_plan", plan_text, label="plan")
|
||||
if speech_text:
|
||||
@@ -390,6 +401,9 @@ class AskVQAFwd(InferenceStep):
|
||||
answer = _generate_with_policy(
|
||||
self.policy, ctx, observation=observation, state=state, label="vqa gen"
|
||||
)
|
||||
# VQA answers are intentionally JSON-like during training, so
|
||||
# ``_looks_like_gibberish`` would false-positive on them. Keep
|
||||
# the answer as-is — the VQA panel line lets the user judge.
|
||||
if answer:
|
||||
push_log(state, f" vqa: {answer}")
|
||||
state["recent_vqa_query"] = None
|
||||
@@ -432,6 +446,38 @@ class DispatchToolCalls(InferenceStep):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _looks_like_gibberish(text: str) -> bool:
|
||||
"""Heuristically detect generation that's clearly off the rails.
|
||||
|
||||
Memorised models can collapse to dominant-mode outputs (often the
|
||||
JSON-token salad ``":":":":...`` from VQA training) when the prompt
|
||||
drifts even slightly from training distribution. If we accept those
|
||||
as new state, they pollute the next tick's prompt and cascade into
|
||||
worse outputs. Reject anything that looks pathological:
|
||||
|
||||
* empty / whitespace-only
|
||||
* mostly punctuation (``"``, ``:``, ``,``)
|
||||
* a single character repeated past the threshold
|
||||
* starts with ``":"`` and contains no letters
|
||||
|
||||
The thresholds are intentionally lenient — a real subtask like
|
||||
``"close the gripper"`` has ~70%+ alpha characters, while gibberish
|
||||
like ``":":":"`` has ~0%.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return True
|
||||
stripped = text.strip()
|
||||
alpha = sum(1 for c in stripped if c.isalpha())
|
||||
if alpha < max(3, len(stripped) // 8):
|
||||
return True
|
||||
if stripped.startswith('":') and stripped.count('"') > stripped.count(" "):
|
||||
return True
|
||||
# Single repeating char: e.g. ``""""""``
|
||||
if len(set(stripped)) <= 2 and len(stripped) > 4:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _control_context_messages(
|
||||
state: dict[str, Any],
|
||||
*,
|
||||
|
||||
@@ -141,6 +141,43 @@ def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None:
|
||||
)
|
||||
print(f"[lerobot-annotate] uploaded to https://huggingface.co/datasets/{repo_id}", flush=True)
|
||||
|
||||
# Tag the upload with the codebase version. ``LeRobotDatasetMetadata``
|
||||
# resolves the dataset revision via ``get_safe_version`` which scans
|
||||
# for tags like ``v3.0``; without a tag it raises
|
||||
# ``RevisionNotFoundError``. Read the version straight from the
|
||||
# dataset's own ``meta/info.json`` so we tag whatever the writer
|
||||
# actually wrote (no accidental drift if the codebase floor moves).
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
version_tag = CODEBASE_VERSION
|
||||
if info_path.exists():
|
||||
try:
|
||||
from lerobot.utils.io_utils import load_json # noqa: PLC0415
|
||||
|
||||
info = load_json(info_path)
|
||||
ds_version = info.get("codebase_version")
|
||||
if isinstance(ds_version, str) and ds_version.startswith("v"):
|
||||
version_tag = ds_version
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[lerobot-annotate] could not read codebase_version from info.json ({exc}); falling back to {version_tag}", flush=True)
|
||||
try:
|
||||
api.create_tag(
|
||||
repo_id=repo_id,
|
||||
tag=version_tag,
|
||||
repo_type="dataset",
|
||||
exist_ok=True,
|
||||
)
|
||||
print(f"[lerobot-annotate] tagged {repo_id} as {version_tag}", flush=True)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(
|
||||
f"[lerobot-annotate] WARNING: could not create tag {version_tag!r} on {repo_id}: {exc}. "
|
||||
"Dataset is uploaded but ``LeRobotDataset`` won't be able to load it until it's tagged. "
|
||||
"Run: from huggingface_hub import HfApi; "
|
||||
f"HfApi().create_tag({repo_id!r}, tag={version_tag!r}, repo_type='dataset', exist_ok=True)",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
annotate()
|
||||
|
||||
@@ -307,6 +307,71 @@ def _build_observation_provider(
|
||||
return _provider
|
||||
|
||||
|
||||
def _bootstrap_state_from_dataset(
    *,
    dataset_repo_id: str,
    episode: int,
    start_frame: int,
) -> dict[str, str]:
    """Recover the task and active plan / memory / subtask at ``start_frame``.

    The model is heavily memorised on the exact prompts the training
    recipe rendered from this dataset (canonical task wording plus the
    persistent atoms emitted earlier in the episode). Rebuilding that
    state at REPL startup makes the runtime's first prompt resemble a
    training prompt; otherwise the model sees an out-of-distribution
    prompt and falls back to its dominant training mode (VQA JSON spam).

    Args:
        dataset_repo_id: Identifier passed to ``LeRobotDataset``.
        episode: The single episode to load.
        start_frame: Frame index inside the loaded episode; clamped to
            the valid range ``[0, len(dataset) - 1]``.

    Returns:
        A dict holding any of the keys ``"task"``, ``"plan"``,
        ``"memory"`` and ``"subtask"`` that could be recovered. Empty
        when the loaded dataset has no frames.
    """
    from lerobot.datasets.lerobot_dataset import LeRobotDataset  # noqa: PLC0415

    dataset = LeRobotDataset(dataset_repo_id, episodes=[episode])
    n_frames = len(dataset)
    if n_frames == 0:
        return {}
    frame = dataset[max(0, min(start_frame, n_frames - 1))]

    state: dict[str, str] = {}
    task_text = frame.get("task")
    if isinstance(task_text, str) and task_text.strip():
        state["task"] = task_text

    rows = frame.get("language_persistent") or []

    # ``rows`` is the broadcast slice of the episode; for each style keep
    # the *latest* row whose ``timestamp`` is ≤ the frame's timestamp
    # (matches the renderer's ``active_at`` semantics).
    try:
        raw_ts = frame["timestamp"]
        cutoff = raw_ts.item() if hasattr(raw_ts, "item") else float(raw_ts)
    except Exception:  # noqa: BLE001
        # Best effort: with no usable frame timestamp, admit every row.
        cutoff = float("inf")

    latest: dict[str, tuple[float, str]] = {}
    for row in rows:
        style = row.get("style")
        stamp = row.get("timestamp")
        content = row.get("content")
        if not style or not content or stamp is None:
            continue
        try:
            stamp_f = float(stamp)
        except (TypeError, ValueError):
            continue
        if stamp_f > cutoff:
            continue
        seen = latest.get(style)
        # ``>=`` lets a later row win a timestamp tie.
        if seen is None or stamp_f >= seen[0]:
            latest[style] = (stamp_f, content)

    wanted = {"plan", "memory", "subtask"}
    state.update(
        {style: content for style, (_, content) in latest.items() if style in wanted}
    )
    return state
|
||||
|
||||
|
||||
def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]:
|
||||
"""Instantiate the tools declared on this dataset/policy."""
|
||||
if no_tts:
|
||||
@@ -364,6 +429,7 @@ def main(argv: list[str] | None = None) -> int:
|
||||
)
|
||||
|
||||
observation_provider: Callable[[], dict | None] | None = None
|
||||
bootstrap_state: dict[str, str] = {}
|
||||
if args.dataset_repo_id is not None:
|
||||
print(
|
||||
f"[smolvla2] streaming observations from {args.dataset_repo_id} "
|
||||
@@ -379,6 +445,25 @@ def main(argv: list[str] | None = None) -> int:
|
||||
preprocessor=preprocessor,
|
||||
device=str(getattr(policy.config, "device", "cpu")),
|
||||
)
|
||||
# Pull the dataset's canonical task + the persistent atoms in
|
||||
# force at the chosen start frame. The model is heavily
|
||||
# memorised on the *exact* training prompts (task wording,
|
||||
# current plan, current memory) — feeding ad-hoc user
|
||||
# alternatives gives it nothing to recall against, so it
|
||||
# collapses to its dominant training mode (VQA JSON). Reading
|
||||
# the canonical state straight from the dataset gives the
|
||||
# runtime a starting point that lines up with training.
|
||||
bootstrap_state = _bootstrap_state_from_dataset(
|
||||
dataset_repo_id=args.dataset_repo_id,
|
||||
episode=args.dataset_episode,
|
||||
start_frame=args.dataset_start_frame,
|
||||
)
|
||||
if bootstrap_state.get("task") and not args.task:
|
||||
args.task = bootstrap_state["task"]
|
||||
print(
|
||||
f"[smolvla2] using canonical task from dataset: {args.task!r}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
tools = _build_tools(args.no_tts, args.tts_voice)
|
||||
if tools:
|
||||
@@ -411,6 +496,17 @@ def main(argv: list[str] | None = None) -> int:
|
||||
)
|
||||
if args.task:
|
||||
runtime.set_task(args.task)
|
||||
# Bootstrap plan/memory from the dataset so the first prompt the
|
||||
# runtime builds matches what training rendered (task + active
|
||||
# plan + active memory). Without this the runtime starts with
|
||||
# plan/memory empty, which only matched the very-early frames in
|
||||
# training and is an out-of-distribution prompt for the rest.
|
||||
if bootstrap_state.get("plan"):
|
||||
runtime.state["current_plan"] = bootstrap_state["plan"]
|
||||
if bootstrap_state.get("memory"):
|
||||
runtime.state["current_memory"] = bootstrap_state["memory"]
|
||||
if bootstrap_state.get("subtask"):
|
||||
runtime.state["current_subtask"] = bootstrap_state["subtask"]
|
||||
|
||||
return _run_repl(runtime, initial_task=args.task, max_ticks=args.max_ticks)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user