common : delegate assistant continuation to underlying template handlers (#23089)

* common : delegate assistant continuation to template handler

* server : implement echo parameter to exclude assistant prefill in the response

* server : fix tests for prefill

* server : use existing llama template

* cont : clean up
This commit is contained in:
Aldehir Rojas
2026-05-17 07:36:05 -04:00
committed by GitHub
parent a6d6183dbc
commit 39cf5d6191
10 changed files with 1112 additions and 191 deletions
+22 -85
View File
@@ -1032,23 +1032,33 @@ json oaicompat_chat_params_parse(
auto caps = common_chat_templates_get_caps(opt.tmpls.get());
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
inputs.tools = common_chat_tools_parse_oaicompat(tools);
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
inputs.grammar = grammar;
inputs.use_jinja = opt.use_jinja;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", caps["supports_parallel_tool_calls"]);
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
const bool continue_final_message = json_value(body, "continue_final_message", false);
if (continue_final_message && inputs.add_generation_prompt) {
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
inputs.tools = common_chat_tools_parse_oaicompat(tools);
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
inputs.grammar = grammar;
inputs.use_jinja = opt.use_jinja;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", caps["supports_parallel_tool_calls"]);
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
inputs.continue_final_message = body.contains("continue_final_message") ?
common_chat_continuation_parse(body.at("continue_final_message")) :
COMMON_CHAT_CONTINUATION_NONE;
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_NONE && opt.prefill_assistant
&& !inputs.messages.empty() && inputs.messages.back().role == "assistant") {
if (inputs.messages.size() >= 2 && inputs.messages[inputs.messages.size() - 2].role == "assistant") {
throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list.");
}
inputs.continue_final_message = COMMON_CHAT_CONTINUATION_AUTO;
inputs.add_generation_prompt = false;
}
if (inputs.continue_final_message != COMMON_CHAT_CONTINUATION_NONE && inputs.add_generation_prompt) {
throw std::invalid_argument("Cannot set both add_generation_prompt and continue_final_message to true.");
}
inputs.reasoning_format = opt.reasoning_format;
inputs.reasoning_format = opt.reasoning_format;
if (body.contains("reasoning_format")) {
inputs.reasoning_format = common_reasoning_format_from_name(body.at("reasoning_format").get<std::string>());
}
inputs.enable_thinking = opt.enable_thinking;
inputs.enable_thinking = opt.enable_thinking;
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
if (body.contains("grammar")) {
throw std::invalid_argument("Cannot use custom grammar constraints with tools.");
@@ -1073,84 +1083,11 @@ json oaicompat_chat_params_parse(
throw std::invalid_argument("invalid type for \"enable_thinking\" (expected boolean, got string)");
}
// if the assistant message appears at the end of list, we do not add end-of-turn token
// for ex. this can be useful to modify the reasoning process in reasoning models
// continue_final_message is the explicit opt in alias from the vLLM/transformers API,
// equivalent to the prefill_assistant heuristic
bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant"
&& (continue_final_message || opt.prefill_assistant);
common_chat_msg last_message;
if (prefill_assistant_message) {
last_message = inputs.messages.back();
inputs.messages.pop_back();
/* sanity check, max one assistant message at the end of the list */
if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){
throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list.");
}
// reject reasoning prefill on channel based templates that do not expose explicit thinking tags
if (!last_message.reasoning_content.empty() && inputs.enable_thinking) {
auto probe_params = common_chat_templates_apply(opt.tmpls.get(), inputs);
if (probe_params.supports_thinking && probe_params.thinking_end_tag.empty()) {
throw std::invalid_argument("Assistant prefill with reasoning_content is not supported yet for this template.");
}
}
inputs.add_generation_prompt = true;
}
inputs.force_pure_content = opt.force_pure_content;
// Apply chat template to the list of messages
auto chat_params = common_chat_templates_apply(opt.tmpls.get(), inputs);
/* Append assistant prefilled message */
if (prefill_assistant_message) {
const bool thinking_active = chat_params.supports_thinking && !chat_params.thinking_end_tag.empty();
const bool has_reasoning = !last_message.reasoning_content.empty();
const bool has_content = !last_message.content.empty() || !last_message.content_parts.empty();
const bool mid_reasoning = has_reasoning && !has_content;
// some templates inject thinking_start in generation_prompt, others let the model emit it
const bool gp_has_think = thinking_active
&& chat_params.generation_prompt.find(chat_params.thinking_start_tag) != std::string::npos;
// open the thinking block when reasoning is present and the template did not inject it
if (has_reasoning) {
if (thinking_active && !gp_has_think) {
chat_params.prompt += chat_params.thinking_start_tag;
}
chat_params.prompt += last_message.reasoning_content;
}
if (thinking_active) {
if (mid_reasoning) {
// model continues inside the thinking block, keep generation_prompt open on think
if (!gp_has_think) {
chat_params.generation_prompt += chat_params.thinking_start_tag;
}
} else {
// close thinking block when reasoning is followed by content, or when the template forced it open
if (has_reasoning || gp_has_think) {
chat_params.prompt += chat_params.thinking_end_tag;
}
// strip thinking_start from generation_prompt so the parser routes model output as content
auto pos = chat_params.generation_prompt.rfind(chat_params.thinking_start_tag);
if (pos != std::string::npos) {
chat_params.generation_prompt = chat_params.generation_prompt.substr(0, pos);
}
}
}
if (!last_message.content_parts.empty()) {
for (auto & p : last_message.content_parts) {
chat_params.prompt += p.text;
}
} else {
chat_params.prompt += last_message.content;
}
}
llama_params["chat_format"] = static_cast<int>(chat_params.format);
llama_params["prompt"] = chat_params.prompt;
if (!chat_params.grammar.empty()) {
+12
View File
@@ -144,6 +144,17 @@ json task_params::to_json(bool only_metrics) const {
//
// task_result_state
//
task_result_state::task_result_state(const common_chat_parser_params & chat_parser_params)
: chat_parser_params(chat_parser_params)
, oai_resp_id("resp_" + random_string())
, oai_resp_reasoning_id("rs_" + random_string())
, oai_resp_message_id("msg_" + random_string()) {
if (!chat_parser_params.echo) {
// initialize chat_msg to avoid emitting a delta containing the assistant prefill
chat_msg = common_chat_parse("", true, chat_parser_params);
}
}
common_chat_msg task_result_state::update_chat_msg(
const std::string & text_added,
bool is_partial,
@@ -421,6 +432,7 @@ task_params server_task::params_from_json_cmpl(
if (data.contains("chat_parser")) {
params.chat_parser_params.parser.load(data.at("chat_parser").get<std::string>());
}
params.chat_parser_params.echo = json_value(data, "echo", false);
}
{
+1 -5
View File
@@ -112,11 +112,7 @@ struct task_result_state {
const std::string oai_resp_message_id;
std::string oai_resp_fc_id; // function call ID for current args delta
task_result_state(const common_chat_parser_params & chat_parser_params)
: chat_parser_params(chat_parser_params)
, oai_resp_id("resp_" + random_string())
, oai_resp_reasoning_id("rs_" + random_string())
, oai_resp_message_id("msg_" + random_string()) {}
task_result_state(const common_chat_parser_params & chat_parser_params);
// parse partial tool calls and update the internal state
common_chat_msg update_chat_msg(
@@ -158,11 +158,12 @@ def test_chat_template():
@pytest.mark.parametrize("prefill,re_prefill", [
("Whill", "Whill"),
([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Wh\n\nill"),
])
def test_chat_template_assistant_prefill(prefill, re_prefill):
global server
server.chat_template = "llama3"
server.jinja = True
server.chat_template_file = "../../../models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"
server.debug = True # to get the "__verbose" object in the response
server.start()
res = server.make_request("POST", "/chat/completions", data={
@@ -175,14 +176,15 @@ def test_chat_template_assistant_prefill(prefill, re_prefill):
})
assert res.status_code == 200
assert "__verbose" in res.body
assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"
assert res.body["__verbose"]["prompt"].endswith(f"<|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}")
def test_chat_template_continue_final_message_vllm_compat():
"""continue_final_message is the vLLM/transformers explicit alias for the prefill_assistant heuristic.
Both must produce the same prompt."""
global server
server.chat_template = "llama3"
server.jinja = True
server.chat_template_file = "../../../models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"
server.debug = True
server.start()
res = server.make_request("POST", "/chat/completions", data={
@@ -197,7 +199,7 @@ def test_chat_template_continue_final_message_vllm_compat():
})
assert res.status_code == 200
assert "__verbose" in res.body
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nWhill"
assert res.body["__verbose"]["prompt"].endswith("<|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nWhill")
def test_chat_template_continue_final_message_mutual_exclusion():