spec : parallel drafting support (#22838)
* spec : refactor * spec : drop support for incompatible vocabs * spec : update common_speculative_init() * cont : pass seq_id * cont : dedup ctx_seq_rm_type * server : sketch the ctx_dft decode loop * server : draft prompt cache and checkpoints * server : improve ctx names * server, spec : transition to unified spec context * cont : sync main and drft contexts * cont : async drft eval when possible * cont : handle non-ckpt models * cont : pass correct n_past for drafting * cont : process images throught the draft context * spec : handle draft running out of context * server : fix mtmd draft processing * server : fix URL for draft model * server : add comment * server : clean-up + dry * speculative-simple : update * spec : fix n_past type * server : fix slot ctx_drft ptr * tools : update readme * naming : improve consistency * spec : refactor for multi-sequence speculative context * cont : prepare params * cont : prepare params * spec : support parallel drafts * server : support parallel drafting * llama : reuse device buffers when possible * server, spec : clean-up * cont : clean-up * cont : minor * spec : reset `drafting` flag at the end * spec : introduce `common_speculative_process()` * spec : allow for multiple spec types (chain of speculators) * replace old type field of type common_speculative_type in the common_params_speculative struct with a vector to allow multiple types to be specified * introduce common_get_enabled_speculative_impls(const std::vector<enum common_speculative_type>) to figure out which implementations the user has enabled * introduce common_speculative_type_from_names(const std::vector<std::string> & names) to parse the already user provided spec types * all speculators run sequentially, best one wins (we verify its drafted tokens) * maximize expected accepted tokens for current round by calculating the product between the probability of accepting current token (n_acc_tokens / n_gen_drafts) and the draft's length --------- Co-authored-by: Petros Sideris <petros.sideris@nokia.com>
This commit is contained in:
@@ -195,11 +195,9 @@
|
||||
| `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) |
|
||||
| `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) |
|
||||
| `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) |
|
||||
| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) |
|
||||
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
|
||||
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
|
||||
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
|
||||
| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
|
||||
| `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
|
||||
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
|
||||
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |
|
||||
|
||||
@@ -244,11 +244,9 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `--spec-draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_SPEC_DRAFT_N_MIN) |
|
||||
| `--spec-draft-p-split, --draft-p-split P` | speculative decoding split probability (default: 0.10)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_SPLIT) |
|
||||
| `--spec-draft-p-min, --draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.75)<br/>(env: LLAMA_ARG_SPEC_DRAFT_P_MIN) |
|
||||
| `--spec-draft-ctx-size, -cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CTX_SIZE) |
|
||||
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
|
||||
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
|
||||
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
|
||||
| `--spec-draft-replace, --spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
|
||||
| `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
|
||||
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
|
||||
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |
|
||||
|
||||
+356
-236
File diff suppressed because it is too large
Load Diff
@@ -76,7 +76,7 @@ json task_params::to_json(bool only_metrics) const {
|
||||
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
|
||||
{"generation_prompt", chat_parser_params.generation_prompt},
|
||||
{"samplers", samplers},
|
||||
{"speculative.type", common_speculative_type_to_str(speculative.type)},
|
||||
{"speculative.types", common_speculative_type_name_str(speculative.types)},
|
||||
{"timings_per_token", timings_per_token},
|
||||
{"post_sampling_probs", post_sampling_probs},
|
||||
{"backend_sampling", sampling.backend_sampling},
|
||||
@@ -133,7 +133,7 @@ json task_params::to_json(bool only_metrics) const {
|
||||
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
|
||||
{"generation_prompt", chat_parser_params.generation_prompt},
|
||||
{"samplers", samplers},
|
||||
{"speculative.type", common_speculative_type_to_str(speculative.type)},
|
||||
{"speculative.types", common_speculative_type_name_str(speculative.types)},
|
||||
{"timings_per_token", timings_per_token},
|
||||
{"post_sampling_probs", post_sampling_probs},
|
||||
{"backend_sampling", sampling.backend_sampling},
|
||||
@@ -296,6 +296,8 @@ task_params server_task::params_from_json_cmpl(
|
||||
|
||||
params.speculative = defaults.speculative;
|
||||
|
||||
// TODO: to keep things simple, we disable speculative parameter adjustments for now
|
||||
#if 0
|
||||
// TODO: for now, be able to adjust only the draft-model based speculative parameters
|
||||
params.speculative.draft.n_min = json_value(data, "speculative.n_min", defaults.speculative.draft.n_min);
|
||||
params.speculative.draft.n_max = json_value(data, "speculative.n_max", defaults.speculative.draft.n_max);
|
||||
@@ -305,7 +307,6 @@ task_params server_task::params_from_json_cmpl(
|
||||
params.speculative.draft.n_min = std::max(params.speculative.draft.n_min, 0);
|
||||
params.speculative.draft.n_max = std::max(params.speculative.draft.n_max, 0);
|
||||
|
||||
#if 0
|
||||
// for debugging and research purposes
|
||||
params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type)));
|
||||
|
||||
@@ -1981,7 +1982,7 @@ size_t server_prompt_cache::n_tokens() const {
|
||||
return res;
|
||||
}
|
||||
|
||||
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) {
|
||||
server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size_tgt, size_t state_size_dft) {
|
||||
// first check if the current state is contained fully in the cache
|
||||
for (auto it = states.begin(); it != states.end(); ++it) {
|
||||
const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens);
|
||||
@@ -2005,11 +2006,13 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<uint8_t> state_data;
|
||||
std::vector<uint8_t> state_data_tgt;
|
||||
std::vector<uint8_t> state_data_dft;
|
||||
|
||||
// check if we can allocate enough memory for the new state
|
||||
try {
|
||||
state_data.resize(state_size);
|
||||
state_data_tgt.resize(state_size_tgt);
|
||||
state_data_dft.resize(state_size_dft);
|
||||
} catch (const std::bad_alloc & e) {
|
||||
SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what());
|
||||
|
||||
@@ -2022,17 +2025,19 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto & cur = states.emplace_back();
|
||||
cur = {
|
||||
states.push_back({
|
||||
/*.tokens =*/ prompt.tokens.clone(),
|
||||
/*.data =*/ std::move(state_data),
|
||||
/*.data =*/ {
|
||||
/*.main =*/ std::move(state_data_tgt),
|
||||
/*.drft =*/ std::move(state_data_dft),
|
||||
},
|
||||
/*.checkpoints =*/ prompt.checkpoints,
|
||||
};
|
||||
});
|
||||
|
||||
return &cur;
|
||||
return &states.back();
|
||||
}
|
||||
|
||||
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) {
|
||||
bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_tgt, llama_context * ctx_dft, int32_t id_slot) {
|
||||
const int lcp_best = prompt.tokens.get_common_prefix(tokens_new);
|
||||
|
||||
float f_keep_best = prompt.tokens.size() > 0 ? float(lcp_best) / prompt.tokens.size() : -1.0f; // empty slot: any cache entry wins
|
||||
@@ -2065,16 +2070,39 @@ bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tok
|
||||
if (it_best != states.end()) {
|
||||
SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best);
|
||||
|
||||
const size_t size = it_best->data.size();
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0);
|
||||
if (n != size) {
|
||||
SRV_WRN("failed to restore state with size %zu\n", size);
|
||||
{
|
||||
auto & data = it_best->data.main;
|
||||
|
||||
return false;
|
||||
const size_t size = data.size();
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_tgt, data.data(), size, id_slot, 0);
|
||||
if (n != size) {
|
||||
SRV_WRN("failed to restore state with size %zu\n", size);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
data.clear();
|
||||
data.shrink_to_fit();
|
||||
}
|
||||
|
||||
it_best->data.clear();
|
||||
it_best->data.shrink_to_fit();
|
||||
{
|
||||
auto & data = it_best->data.drft;
|
||||
|
||||
if (!data.empty()) {
|
||||
GGML_ASSERT(ctx_dft);
|
||||
|
||||
const size_t size = data.size();
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_dft, data.data(), size, id_slot, 0);
|
||||
if (n != size) {
|
||||
SRV_WRN("failed to restore state with size %zu\n", size);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
data.clear();
|
||||
data.shrink_to_fit();
|
||||
}
|
||||
}
|
||||
|
||||
prompt = std::move(*it_best);
|
||||
|
||||
|
||||
+14
-27
@@ -565,42 +565,29 @@ struct server_task_result_apply_lora : server_task_result {
|
||||
virtual json to_json() override;
|
||||
};
|
||||
|
||||
struct server_prompt_checkpoint {
|
||||
llama_pos pos_min;
|
||||
llama_pos pos_max;
|
||||
|
||||
int64_t n_tokens;
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
struct server_prompt_data {
|
||||
std::vector<uint8_t> main;
|
||||
std::vector<uint8_t> drft;
|
||||
|
||||
size_t size() const {
|
||||
return data.size();
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return data.empty();
|
||||
}
|
||||
|
||||
void clear() {
|
||||
pos_min = 0;
|
||||
pos_max = 0;
|
||||
n_tokens = 0;
|
||||
data.clear();
|
||||
return main.size() + drft.size();
|
||||
}
|
||||
};
|
||||
|
||||
struct server_prompt {
|
||||
server_tokens tokens;
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
server_prompt_data data;
|
||||
|
||||
std::list<server_prompt_checkpoint> checkpoints;
|
||||
std::list<common_prompt_checkpoint> checkpoints;
|
||||
|
||||
size_t size() const {
|
||||
size_t res = data.size();
|
||||
size_t res = 0;
|
||||
|
||||
for (const auto & checkpoint : checkpoints) {
|
||||
res += checkpoint.size();
|
||||
res += data.size();
|
||||
|
||||
for (const auto & ckpt : checkpoints) {
|
||||
res += ckpt.size();
|
||||
}
|
||||
|
||||
return res;
|
||||
@@ -614,7 +601,7 @@ struct server_prompt {
|
||||
return server_prompt {
|
||||
tokens.clone(),
|
||||
data,
|
||||
checkpoints
|
||||
checkpoints,
|
||||
};
|
||||
}
|
||||
};
|
||||
@@ -637,9 +624,9 @@ struct server_prompt_cache {
|
||||
|
||||
size_t n_tokens() const;
|
||||
|
||||
server_prompt * alloc(const server_prompt & prompt, size_t state_size);
|
||||
server_prompt * alloc(const server_prompt & prompt, size_t state_size_main, size_t state_size_drft);
|
||||
|
||||
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);
|
||||
bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx_main, llama_context * ctx_drft, int32_t id_slot);
|
||||
|
||||
void update();
|
||||
};
|
||||
|
||||
@@ -5,7 +5,7 @@ from utils import *
|
||||
|
||||
server = ServerPreset.stories15m_moe()
|
||||
|
||||
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"
|
||||
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/tiny-llamas/resolve/main/stories15M-q4_0.gguf"
|
||||
|
||||
def create_server():
|
||||
global server
|
||||
|
||||
Reference in New Issue
Block a user