backend sampling: support returning post-sampling probs (#22622)

* server: Never return 0.0 post-sampling probabilities

* backend sampling: support returning post-sampling probs
This commit is contained in:
Tim Neumann
2026-05-10 19:12:02 +02:00
committed by GitHub
parent 5d5d2e15d2
commit 2e97c5f96f
4 changed files with 80 additions and 16 deletions
+9 -3
View File
@@ -1317,7 +1317,7 @@ private:
return false;
}
const bool need_logits = task.params.sampling.n_probs > 0;
const bool need_pre_sample_logits = task.params.sampling.n_probs > 0 && !task.params.post_sampling_probs;
bool backend_sampling = true;
@@ -1326,8 +1326,8 @@ private:
// TODO: speculative decoding requires multiple samples per batch - not supported yet
backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0);
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
backend_sampling &= !need_logits;
// TODO: getting pre sampling logits is not yet supported with backend sampling
backend_sampling &= !need_pre_sample_logits;
// TODO: tmp until backend sampling is fully implemented
if (backend_sampling) {
@@ -1504,6 +1504,12 @@ private:
// set probability for top n_probs tokens
result.probs.reserve(n_probs);
for (size_t i = 0; i < n_probs; i++) {
// Some samplers do return 0.0 probabilities, others don't.
// Filter out 0.0 probabilities to ensure the behavior is consistent.
if (cur_p->data[i].p == 0.0) {
break;
}
result.probs.push_back({
cur_p->data[i].id,
common_token_to_piece(ctx, cur_p->data[i].id, special),
+60 -7
View File
@@ -491,29 +491,82 @@ def test_n_probs_post_sampling():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"prompt": "Today was the day. Today I would finally become a",
"n_probs": 10,
"temperature": 0.0,
"temperature": 1.0,
"n_predict": 5,
"post_sampling_probs": True,
})
assert res.status_code == 200
assert "completion_probabilities" in res.body
assert len(res.body["completion_probabilities"]) == 5
for tok in res.body["completion_probabilities"]:
for (i, tok) in enumerate(res.body["completion_probabilities"]):
assert "id" in tok and tok["id"] > 0
assert "token" in tok and type(tok["token"]) == str
assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
assert "bytes" in tok and type(tok["bytes"]) == list
assert len(tok["top_probs"]) == 10
assert "top_probs" in tok and type(tok["top_probs"]) == list
for prob in tok["top_probs"]:
assert "id" in prob and prob["id"] > 0
assert "token" in prob and type(prob["token"]) == str
assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
# 0.0 probability tokens should never be returned by the server
assert "prob" in prob and 0.0 < prob["prob"] <= 1.0
assert "bytes" in prob and type(prob["bytes"]) == list
# because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
if i == 0:
# The prompt is vague enough that we should get at least 10 possibilities
# for the first token.
assert len(tok["top_probs"]) == 10
if len(tok["top_probs"]) < 10:
# Getting fewer than the requested number of probabilities should only happen
# if the ones we did get already sum to 1.0.
assert sum(p["prob"] for p in tok["top_probs"]) == pytest.approx(1.0)
def test_n_probs_post_backend_sampling():
    """Check that post-sampling probabilities match with and without backend sampling.

    The server is started with backend sampling enabled, then the same seeded
    request is issued twice: once with ``backend_sampling`` set to True and once
    with it set to False. Both responses must agree token by token.
    """
    global server
    server.backend_sampling = True
    server.start()

    def assert_same_token(lhs, rhs):
        # Identity fields must match exactly; probabilities only within an
        # absolute tolerance of 0.01.
        assert lhs["id"] == rhs["id"]
        assert lhs["token"] == rhs["token"]
        assert lhs["bytes"] == rhs["bytes"]
        assert lhs["prob"] == pytest.approx(rhs["prob"], abs=0.01)

    def fetch_completions(backend_sampling):
        n_predict = 20
        payload = {
            "prompt": "The countries of Europe, in random order, are:",
            "n_probs": 10,
            "n_predict": n_predict,
            "post_sampling_probs": True,
            "seed": 4242,
            "backend_sampling": backend_sampling,
        }
        res = server.make_request("POST", "/completion", data=payload)
        assert res.status_code == 200

        completions = res.body["completion_probabilities"]
        assert len(completions) == n_predict

        # Handling of 0.0 probabilities differs between samplers and backend
        # sampling, so drop them in place to normalize the data before comparing.
        for entry in completions:
            entry["top_probs"] = [cand for cand in entry["top_probs"] if cand["prob"] > 0.0]

        # Require at least two top probs per token on average; otherwise the
        # comparison below would not be an effective test.
        kept = sum(len(entry["top_probs"]) for entry in completions)
        assert kept >= 2 * n_predict
        return completions

    # Request order matters for the execution trace: backend sampling first.
    with_backend = fetch_completions(True)
    without_backend = fetch_completions(False)

    for lhs, rhs in zip(with_backend, without_backend):
        assert_same_token(lhs, rhs)
        assert len(lhs["top_probs"]) == len(rhs["top_probs"])
        for lhs_top, rhs_top in zip(lhs["top_probs"], rhs["top_probs"]):
            assert_same_token(lhs_top, rhs_top)
@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
def test_logit_bias(tokenize, openai_style):
+3
View File
@@ -108,6 +108,7 @@ class ServerProcess:
no_cache_idle_slots: bool = False
log_path: str | None = None
webui_mcp_proxy: bool = False
backend_sampling: bool = False
gcp_compat: bool = False
# session variables
@@ -252,6 +253,8 @@ class ServerProcess:
server_args.append("--no-cache-idle-slots")
if self.webui_mcp_proxy:
server_args.append("--webui-mcp-proxy")
if self.backend_sampling:
server_args.append("--backend_sampling")
if self.gcp_compat:
env["AIP_MODE"] = "PREDICTION"