Skip to content

Fix sample_logits crash when top_k exceeds vocab size#1347

Merged
jlarson4 merged 1 commit into
TransformerLensOrg:devfrom
robbiebusinessacc:contrib/clamp-sample-logits-top-k
May 29, 2026
Merged

Fix sample_logits crash when top_k exceeds vocab size#1347
jlarson4 merged 1 commit into
TransformerLensOrg:devfrom
robbiebusinessacc:contrib/clamp-sample-logits-top-k

Conversation

@robbiebusinessacc
Copy link
Copy Markdown

sample_logits passes top_k straight to final_logits.topk(top_k). When
top_k is larger than the vocabulary size (e.g.
model.generate(top_k=100_000) on a small-vocab model), .topk() raises
RuntimeError: selected index k out of range.

HuggingFace's TopKLogitsWarper handles this by clamping (top_k = min(top_k, logits.size(-1))). This change does the same: after the existing top_k > 0
assertion, clamp top_k to the last-dim size before calling .topk().
Behaviour is unchanged whenever top_k <= vocab_size.

While here, I removed a stale #! TODO: Finish testing all the edge cases
scratch-code comment and replaced it with a one-line docstring note describing
the clamping behaviour, since this PR adds those edge-case tests.

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature
    works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect
    backward compatibility

Notes

sample_logits had no test coverage before this. The PR adds tests for the
top_k > vocab_size clamp (the crash case), exact-vocab-size top_k, plus
existing-behaviour regression cases. Verified by reverting only the clamp
line, which reproduces the original RuntimeError: selected index k out of range; restoring it makes all tests pass. black, isort, and pycln are
clean.

sample_logits crashed with 'selected index k out of range' when top_k
exceeded the vocabulary size (reachable via model.generate(top_k=...)).
Clamp top_k to the vocab size, matching HuggingFace's TopKLogitsWarper.

Also add the first unit tests for sample_logits (top_k clamping, greedy
temperature=0, top_p, frequency penalty, repetition penalty), resolving
the standing in-code TODO.
@jlarson4
Copy link
Copy Markdown
Collaborator

Excellent work on this! Merging for inclusion in the next release

@jlarson4 jlarson4 merged commit cdfab1a into TransformerLensOrg:dev May 29, 2026
24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants