llama + spec: MTP Support (#22673)
* spec: support MTP * fix batch size * rename files * cont : simplify (#7) * MTP: clean-up (#9) * MTP: clean-up * review: use llama_context_type instead of llama_graph_type * review: remove llama_model_has_mtp * review: fix convert issues * convert: fix pycheck * review: formatting * use `mtp-` for identifying mtp models * convert: fix mtp conversion * mtp -> draft-mtp * remove unused llama_arch * add need_embd in speculative * llama: allow partial seq_rm for GDN models for speculative decoding Currently speculative checkpoint needs to restart from a checkpoint after some draft tokens are not accepted, this leads to some wastage in running the target again. This PR adds the ability to rollback upto `draft_max` by storing the GDN intermediates. * fix pending state * vulkan: add GDN partial rollback * meta: extend check to axis 1 * metal: add GDN partial rollback Extend the gated delta net kernel to store intermediate states for partial rollback support on the Metal backend. - Add K (snapshot slot count) as a function constant - Read input state from slot 0 of the 3D state tensor - Write intermediate states to different slots during token loop - For K=1, maintain backward-compatible single-slot behavior Ref: https://github.com/ggml-org/llama.cpp/commit/8c05923630110223669f069af2000e9cf10c02bc Assisted-by: llama.cpp:local pi * delta_net_base: use ggml_pad instead of new_tensor * review: add need_rs_seq * review: rename part_bounded to n_rs * review: deslop comments * review: rename, add asserts * server : adjust checkpoint logic (#11) * server : adjust checkpoint logic * cont : rm asserts * server-context: fix early exit * spec : fix compatibility with n-gram and add TODOs (#13) * metal : cleanup * llama : fix faulty bitwise check in recurrent memory * server : disable RS-based MTP in combination with other spec types * spec : add TODOs * cont : fix comment * cont : update comment * common : fix logic for ngram + mtp compat * llama-memory: enable checkpointing with 
partial rollback * cont: add test-case for loading into a dirty ctx * llama-memory-recurrent: clear rs_idx in clear * download: fix mtp path * llama-arch: fix enorm op * docs: update docs * conversion: fix type annotations --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
+28
-6
@@ -337,11 +337,15 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
|
||||
// Outcome of resolving one model: records which companion files were reported
// next to the model and where they are stored.
struct handle_model_result {
|
||||
// true when the download result carried a non-empty mmproj path
bool found_mmproj = false;
|
||||
// mmproj.path receives the reported mmproj path when found_mmproj is set
common_params_model mmproj;
|
||||
|
||||
// true when the download result carried a non-empty MTP path
bool found_mtp = false;
|
||||
// mtp.path receives the reported MTP path when found_mtp is set
common_params_model mtp;
|
||||
};
|
||||
|
||||
static handle_model_result common_params_handle_model(struct common_params_model & model,
|
||||
const std::string & bearer_token,
|
||||
bool offline) {
|
||||
bool offline,
|
||||
bool search_mtp = false) {
|
||||
handle_model_result result;
|
||||
|
||||
if (!model.docker_repo.empty()) {
|
||||
@@ -356,7 +360,7 @@ static handle_model_result common_params_handle_model(struct common_params_model
|
||||
common_download_opts opts;
|
||||
opts.bearer_token = bearer_token;
|
||||
opts.offline = offline;
|
||||
auto download_result = common_download_model(model, opts, true);
|
||||
auto download_result = common_download_model(model, opts, true, search_mtp);
|
||||
|
||||
if (download_result.model_path.empty()) {
|
||||
throw std::runtime_error("failed to download model from Hugging Face");
|
||||
@@ -369,6 +373,11 @@ static handle_model_result common_params_handle_model(struct common_params_model
|
||||
result.found_mmproj = true;
|
||||
result.mmproj.path = download_result.mmproj_path;
|
||||
}
|
||||
|
||||
if (!download_result.mtp_path.empty()) {
|
||||
result.found_mtp = true;
|
||||
result.mtp.path = download_result.mtp_path;
|
||||
}
|
||||
} else if (!model.url.empty()) {
|
||||
if (model.path.empty()) {
|
||||
auto f = string_split<std::string>(model.url, '#').front();
|
||||
@@ -436,7 +445,11 @@ static bool parse_bool_value(const std::string & value) {
|
||||
//
|
||||
|
||||
void common_params_handle_models(common_params & params, llama_example curr_ex) {
|
||||
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
params.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
|
||||
|
||||
auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_draft_mtp);
|
||||
if (params.no_mmproj) {
|
||||
params.mmproj = {};
|
||||
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
|
||||
@@ -450,6 +463,14 @@ void common_params_handle_models(common_params & params, llama_example curr_ex)
|
||||
break;
|
||||
}
|
||||
}
|
||||
// when --spec-type mtp is set and no draft model was provided explicitly,
|
||||
// fall back to the MTP head discovered alongside the -hf model
|
||||
if (spec_type_draft_mtp && res.found_mtp &&
|
||||
params.speculative.draft.mparams.path.empty() &&
|
||||
params.speculative.draft.mparams.hf_repo.empty() &&
|
||||
params.speculative.draft.mparams.url.empty()) {
|
||||
params.speculative.draft.mparams.path = res.mtp.path;
|
||||
}
|
||||
common_params_handle_model(params.speculative.draft.mparams, params.hf_token, params.offline);
|
||||
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
|
||||
}
|
||||
@@ -3608,8 +3629,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
string_format("comma-separated list of types of speculative decoding to use (default: %s)\n",
|
||||
common_speculative_type_name_str(params.speculative.types).c_str()),
|
||||
[](common_params & params, const std::string & value) {
|
||||
const auto enabled_types = string_split<std::string>(value, ',');
|
||||
params.speculative.types = common_speculative_types_from_names(enabled_types);
|
||||
const auto types_str = string_split<std::string>(value, ',');
|
||||
auto types = common_speculative_types_from_names(types_str);
|
||||
params.speculative.types.insert(params.speculative.types.end(), types.begin(), types.end());
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_TYPE"));
|
||||
add_opt(common_arg(
|
||||
@@ -4098,7 +4120,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--spec-default"},
|
||||
string_format("enable default speculative decoding config"),
|
||||
[](common_params & params) {
|
||||
params.speculative.types = { COMMON_SPECULATIVE_TYPE_NGRAM_MOD };
|
||||
params.speculative.types.push_back(COMMON_SPECULATIVE_TYPE_NGRAM_MOD);
|
||||
params.speculative.ngram_mod.n_match = 24;
|
||||
params.speculative.ngram_mod.n_min = 48;
|
||||
params.speculative.ngram_mod.n_max = 64;
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "sampling.h"
|
||||
#include "speculative.h"
|
||||
#include "unicode.h"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -1247,6 +1248,29 @@ common_init_result::common_init_result(common_params & params) :
|
||||
cparams.n_samplers = pimpl->samplers_seq_config.size();
|
||||
}
|
||||
|
||||
// [TAG_RS_STATE_ROLLBACK_SUPPORT]
|
||||
// TODO: ngram speculative methods require checkpointing in addition to partial RS rollback
|
||||
// currently this is not supported, so we disable the partial rollback
|
||||
if (cparams.n_rs_seq > 0 && (llama_model_is_recurrent(model) || llama_model_is_hybrid(model))) {
|
||||
auto & types = params.speculative.types;
|
||||
|
||||
for (int i = 0; i < (int) types.size(); i++) {
|
||||
if (types[i] == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||
continue;
|
||||
}
|
||||
if (types[i] == COMMON_SPECULATIVE_TYPE_DRAFT_MTP) {
|
||||
continue;
|
||||
}
|
||||
|
||||
cparams.n_rs_seq = 0;
|
||||
|
||||
LOG_WRN("%s: recurrent state rollback is not compatible with '%s' - disabling rollback support\n", __func__,
|
||||
common_speculative_type_to_str(types[i]).c_str());
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
@@ -1435,6 +1459,12 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
|
||||
goto done;
|
||||
}
|
||||
|
||||
if (llama_n_rs_seq(ctx) > 0) {
|
||||
LOG_INF("%s: the context supports bounded partial sequence removal\n", __func__);
|
||||
res = COMMON_CONTEXT_SEQ_RM_TYPE_RS;
|
||||
goto done;
|
||||
}
|
||||
|
||||
// try to remove the last tokens
|
||||
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
|
||||
LOG_TRC("%s: the context does not support partial sequence removal\n", __func__);
|
||||
@@ -1449,6 +1479,23 @@ done:
|
||||
return res;
|
||||
}
|
||||
|
||||
// Asks the context's memory module to remove the (p0, p1) range of sequence
// `seq_id`. Aborts the process with a descriptive message if the request is refused.
void common_context_seq_rm(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    const bool removed = llama_memory_seq_rm(llama_get_memory(ctx), seq_id, p0, p1);
    if (removed) {
        return;
    }

    const std::string msg = string_format("failed to remove sequence %d with p0=%d, p1=%d\n", seq_id, p0, p1);
    GGML_ABORT("%s", msg.c_str());
}
|
||||
|
||||
// Forwards a sequence copy request (seq_id_src -> seq_id_dst over the (p0, p1) range)
// to the memory module of `ctx`.
void common_context_seq_cp(llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    llama_memory_seq_cp(llama_get_memory(ctx), seq_id_src, seq_id_dst, p0, p1);
}
|
||||
|
||||
// Forwards a position shift request (`delta` applied to the (p0, p1) range of `seq_id`)
// to the memory module of `ctx`.
void common_context_seq_add(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
    llama_memory_seq_add(llama_get_memory(ctx), seq_id, p0, p1, delta);
}
|
||||
|
||||
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
|
||||
std::vector<llama_adapter_lora *> loras;
|
||||
std::vector<float> scales;
|
||||
@@ -1505,6 +1552,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
||||
|
||||
cparams.n_ctx = params.n_ctx;
|
||||
cparams.n_seq_max = params.n_parallel;
|
||||
cparams.n_rs_seq = params.speculative.need_n_rs_seq();
|
||||
cparams.n_batch = params.n_batch;
|
||||
cparams.n_ubatch = params.n_ubatch;
|
||||
cparams.n_threads = params.cpuparams.n_threads;
|
||||
@@ -2074,3 +2122,11 @@ void common_prompt_checkpoint::load_dft(
|
||||
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n);
|
||||
}
|
||||
}
|
||||
|
||||
// Empties the data_tgt buffer (presumably the stored target-context checkpoint bytes).
void common_prompt_checkpoint::clear_tgt() {
|
||||
data_tgt.clear();
|
||||
}
|
||||
|
||||
// Empties the data_dft buffer (presumably the stored draft-context checkpoint bytes).
void common_prompt_checkpoint::clear_dft() {
|
||||
data_dft.clear();
|
||||
}
|
||||
|
||||
+22
-4
@@ -13,6 +13,7 @@
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
|
||||
#if defined(_WIN32) && !defined(_WIN32_WINNT)
|
||||
#define _WIN32_WINNT 0x0A00
|
||||
@@ -159,6 +160,7 @@ enum common_speculative_type {
|
||||
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // Multi-token prediction
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
|
||||
@@ -301,7 +303,7 @@ struct common_params_speculative_draft {
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy) // TODO: change default to 0.0f
|
||||
|
||||
common_params_model mparams;
|
||||
|
||||
@@ -355,6 +357,14 @@ struct common_params_speculative {
|
||||
bool has_dft() const {
|
||||
return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty();
|
||||
}
|
||||
|
||||
uint32_t need_n_rs_seq() const {
|
||||
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
|
||||
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP;
|
||||
});
|
||||
|
||||
return needs_rs_seq ? draft.n_max : 0u;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_params_vocoder {
|
||||
@@ -884,15 +894,20 @@ std::string common_get_model_endpoint();
|
||||
//
|
||||
|
||||
enum common_context_seq_rm_type {
|
||||
COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module)
|
||||
COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
|
||||
COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
|
||||
COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module)
|
||||
COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
|
||||
COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
|
||||
COMMON_CONTEXT_SEQ_RM_TYPE_RS = 3, // can seq_rm partial sequences, bounded by n_rs_seq
|
||||
};
|
||||
|
||||
// check if the llama_context can remove sequences
|
||||
// note: clears the memory of the context
|
||||
common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx);
|
||||
|
||||
// aborts execution on failure
|
||||
void common_context_seq_rm (llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
|
||||
void common_context_seq_add(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
|
||||
void common_context_seq_cp (llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1);
|
||||
|
||||
//
|
||||
// Batch utils
|
||||
@@ -1074,4 +1089,7 @@ struct common_prompt_checkpoint {
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_state_seq_flags flags) const;
|
||||
|
||||
void clear_tgt();
|
||||
void clear_dft();
|
||||
};
|
||||
|
||||
+42
-13
@@ -566,8 +566,11 @@ static hf_cache::hf_files get_split_files(const hf_cache::hf_files & files,
|
||||
return result;
|
||||
}
|
||||
|
||||
static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
|
||||
const std::string & model) {
|
||||
// pick the best sibling GGUF whose filename contains `keyword` (e.g. "mmproj" / "mtp"),
|
||||
// preferring deeper shared directory prefix with the model, then closest quantization
|
||||
static hf_cache::hf_file find_best_sibling(const hf_cache::hf_files & files,
|
||||
const std::string & model,
|
||||
const std::string & keyword) {
|
||||
hf_cache::hf_file best;
|
||||
size_t best_depth = 0;
|
||||
int best_diff = 0;
|
||||
@@ -579,20 +582,20 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
|
||||
|
||||
for (const auto & f : files) {
|
||||
if (!string_ends_with(f.path, ".gguf") ||
|
||||
f.path.find("mmproj") == std::string::npos) {
|
||||
f.path.find(keyword) == std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto mmproj_parts = string_split<std::string>(f.path, '/');
|
||||
auto mmproj_dir = mmproj_parts.end() - 1;
|
||||
auto sib_parts = string_split<std::string>(f.path, '/');
|
||||
auto sib_dir = sib_parts.end() - 1;
|
||||
|
||||
auto [_, dir] = std::mismatch(model_parts.begin(), model_dir,
|
||||
mmproj_parts.begin(), mmproj_dir);
|
||||
if (dir != mmproj_dir) {
|
||||
sib_parts.begin(), sib_dir);
|
||||
if (dir != sib_dir) {
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t depth = dir - mmproj_parts.begin();
|
||||
size_t depth = dir - sib_parts.begin();
|
||||
auto bits = extract_quant_bits(f.path);
|
||||
auto diff = std::abs(bits - model_bits);
|
||||
|
||||
@@ -606,6 +609,16 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
|
||||
return best;
|
||||
}
|
||||
|
||||
// Sibling search for the multimodal projector: delegates to find_best_sibling
// with the filename keyword "mmproj".
static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
|
||||
const std::string & model) {
|
||||
return find_best_sibling(files, model, "mmproj");
|
||||
}
|
||||
|
||||
// Sibling search for the MTP head: delegates to find_best_sibling with the
// filename keyword "mtp-" (note the trailing dash).
static hf_cache::hf_file find_best_mtp(const hf_cache::hf_files & files,
|
||||
const std::string & model) {
|
||||
return find_best_sibling(files, model, "mtp-");
|
||||
}
|
||||
|
||||
static bool gguf_filename_is_model(const std::string & filepath) {
|
||||
if (!string_ends_with(filepath, ".gguf")) {
|
||||
return false;
|
||||
@@ -617,7 +630,8 @@ static bool gguf_filename_is_model(const std::string & filepath) {
|
||||
}
|
||||
|
||||
return filename.find("mmproj") == std::string::npos &&
|
||||
filename.find("imatrix") == std::string::npos;
|
||||
filename.find("imatrix") == std::string::npos &&
|
||||
filename.find("mtp-") == std::string::npos;
|
||||
}
|
||||
|
||||
static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
|
||||
@@ -673,11 +687,13 @@ struct hf_plan {
|
||||
hf_cache::hf_file primary;
|
||||
hf_cache::hf_files model_files;
|
||||
hf_cache::hf_file mmproj;
|
||||
hf_cache::hf_file mtp;
|
||||
};
|
||||
|
||||
static hf_plan get_hf_plan(const common_params_model & model,
|
||||
const common_download_opts & opts,
|
||||
bool download_mmproj) {
|
||||
bool download_mmproj,
|
||||
bool download_mtp) {
|
||||
hf_plan plan;
|
||||
hf_cache::hf_files all;
|
||||
|
||||
@@ -723,6 +739,10 @@ static hf_plan get_hf_plan(const common_params_model & model,
|
||||
plan.mmproj = find_best_mmproj(all, primary.path);
|
||||
}
|
||||
|
||||
if (download_mtp) {
|
||||
plan.mtp = find_best_mtp(all, primary.path);
|
||||
}
|
||||
|
||||
return plan;
|
||||
}
|
||||
|
||||
@@ -756,7 +776,8 @@ static std::vector<download_task> get_url_tasks(const common_params_model & mode
|
||||
|
||||
common_download_model_result common_download_model(const common_params_model & model,
|
||||
const common_download_opts & opts,
|
||||
bool download_mmproj) {
|
||||
bool download_mmproj,
|
||||
bool download_mtp) {
|
||||
common_download_model_result result;
|
||||
std::vector<download_task> tasks;
|
||||
hf_plan hf;
|
||||
@@ -764,13 +785,16 @@ common_download_model_result common_download_model(const common_params_model &
|
||||
bool is_hf = !model.hf_repo.empty();
|
||||
|
||||
if (is_hf) {
|
||||
hf = get_hf_plan(model, opts, download_mmproj);
|
||||
hf = get_hf_plan(model, opts, download_mmproj, download_mtp);
|
||||
for (const auto & f : hf.model_files) {
|
||||
tasks.push_back({f.url, f.local_path});
|
||||
}
|
||||
if (!hf.mmproj.path.empty()) {
|
||||
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
|
||||
}
|
||||
if (!hf.mtp.path.empty()) {
|
||||
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
|
||||
}
|
||||
} else if (!model.url.empty()) {
|
||||
tasks = get_url_tasks(model);
|
||||
} else {
|
||||
@@ -807,6 +831,10 @@ common_download_model_result common_download_model(const common_params_model &
|
||||
if (!hf.mmproj.path.empty()) {
|
||||
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
|
||||
}
|
||||
|
||||
if (!hf.mtp.path.empty()) {
|
||||
result.mtp_path = hf_cache::finalize_file(hf.mtp);
|
||||
}
|
||||
} else {
|
||||
result.model_path = model.path;
|
||||
}
|
||||
@@ -946,7 +974,8 @@ std::vector<common_cached_model_info> common_list_cached_models() {
|
||||
for (const auto & f : files) {
|
||||
auto split = get_gguf_split_info(f.path);
|
||||
if (split.index != 1 || split.tag.empty() ||
|
||||
split.prefix.find("mmproj") != std::string::npos) {
|
||||
split.prefix.find("mmproj") != std::string::npos ||
|
||||
split.prefix.find("mtp-") != std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
if (seen.insert(f.repo_id + ":" + split.tag).second) {
|
||||
|
||||
+5
-2
@@ -59,6 +59,7 @@ struct common_download_opts {
|
||||
// Paths produced by common_download_model; a field is left empty when the
// corresponding file was not found or the download failed.
struct common_download_model_result {
|
||||
// path of the main model file
std::string model_path;
|
||||
// path of the mmproj file, if one was requested and found
std::string mmproj_path;
|
||||
// path of the MTP-head file, if one was requested and found
std::string mtp_path;
|
||||
};
|
||||
|
||||
// Download model from HuggingFace repo or URL
|
||||
@@ -83,12 +84,14 @@ struct common_download_model_result {
|
||||
// when opts.offline=true, no network requests are made
|
||||
// when download_mmproj=true, searches for mmproj in same directory as model or any parent directory
|
||||
// then with the closest quantization bits
|
||||
// when download_mtp=true, applies the same sibling search for an MTP-head GGUF
|
||||
//
|
||||
// returns result with model_path and mmproj_path (empty on failure)
|
||||
// returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure)
|
||||
common_download_model_result common_download_model(
|
||||
const common_params_model & model,
|
||||
const common_download_opts & opts = {},
|
||||
bool download_mmproj = false
|
||||
bool download_mmproj = false,
|
||||
bool download_mtp = false
|
||||
);
|
||||
|
||||
// returns list of cached models
|
||||
|
||||
+377
-8
@@ -3,6 +3,7 @@
|
||||
#include "common.h"
|
||||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP)
|
||||
#include "log.h"
|
||||
#include "ngram-cache.h"
|
||||
#include "ngram-map.h"
|
||||
@@ -23,6 +24,7 @@ const std::map<std::string, common_speculative_type> common_speculative_type_fro
|
||||
{"none", COMMON_SPECULATIVE_TYPE_NONE},
|
||||
{"draft-simple", COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE},
|
||||
{"draft-eagle3", COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3},
|
||||
{"draft-mtp", COMMON_SPECULATIVE_TYPE_DRAFT_MTP},
|
||||
{"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
|
||||
{"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
|
||||
{"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
|
||||
@@ -143,6 +145,9 @@ struct common_speculative_impl {
|
||||
virtual void draft(common_speculative_draft_params_vec & dparams) = 0;
|
||||
|
||||
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0;
|
||||
|
||||
// true if this implementation requires the target context to extract embeddings
|
||||
virtual bool need_embd() const = 0;
|
||||
};
|
||||
|
||||
struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
@@ -338,6 +343,10 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
// Acceptance hook: intentionally does nothing here; both arguments are ignored.
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
|
||||
// noop
|
||||
}
|
||||
|
||||
// Reports that this implementation does not require the target context to extract embeddings.
bool need_embd() const override {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
@@ -362,6 +371,328 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
// Acceptance hook: intentionally does nothing here; both arguments are ignored.
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
|
||||
// noop
|
||||
}
|
||||
|
||||
// Reports that this implementation does not require the target context to extract embeddings.
bool need_embd() const override {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
||||
common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)
|
||||
|
||||
llama_batch batch;
|
||||
|
||||
std::vector<common_sampler_ptr> smpls;
|
||||
|
||||
int32_t n_embd = 0;
|
||||
|
||||
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
|
||||
// The last h-row of one process() call needs the first token of the NEXT
|
||||
// call to pair with, so it's stashed here until that next call fires.
|
||||
std::vector<std::vector<float>> pending_h; // [n_seq][n_embd]
|
||||
|
||||
std::vector<int32_t> i_batch_beg;
|
||||
std::vector<int32_t> i_batch_end;
|
||||
|
||||
// Hidden rows from the most recent target verification batch, grouped by seq.
|
||||
// Row 0 corresponds to the sampled token, row N to the Nth accepted draft token.
|
||||
std::vector<std::vector<float>> verify_h;
|
||||
std::vector<int32_t> verify_h_rows;
|
||||
|
||||
// Per-seq draft length from the last draft() call, used in accept() to
|
||||
// roll back ctx_dft's recurrent state past the AR draft's redundant
|
||||
// pre-advancement before process() mirrored the verify batch.
|
||||
std::vector<uint16_t> last_n_drafted;
|
||||
|
||||
common_speculative_state_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
|
||||
, params(params.draft)
|
||||
{
|
||||
auto * ctx_tgt = this->params.ctx_tgt;
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");
|
||||
|
||||
n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
|
||||
|
||||
const int32_t n_b = (int32_t) llama_n_batch(ctx_dft);
|
||||
batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
|
||||
// llama_batch_init allocates only one of token/embd; MTP needs both.
|
||||
// TODO: fix, how to call without malloc
|
||||
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_b);
|
||||
|
||||
smpls.resize(n_seq);
|
||||
for (auto & s : smpls) {
|
||||
common_params_sampling sparams;
|
||||
sparams.no_perf = false;
|
||||
sparams.top_k = 1; // TODO: re-enable top_k == 10 and utilize `p_min` spec param
|
||||
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
|
||||
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
|
||||
}
|
||||
|
||||
llama_set_embeddings_pre_norm(ctx_tgt, true);
|
||||
llama_set_embeddings_pre_norm(ctx_dft, true);
|
||||
|
||||
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
|
||||
|
||||
i_batch_beg.assign(n_seq, -1);
|
||||
i_batch_end.assign(n_seq, -1);
|
||||
|
||||
verify_h.assign(n_seq, {});
|
||||
verify_h_rows.assign(n_seq, 0);
|
||||
|
||||
last_n_drafted.assign(n_seq, 0);
|
||||
}
|
||||
|
||||
// Releases the batch, including the token buffer that the constructor malloc'ed
// separately from llama_batch_init.
~common_speculative_state_draft_mtp() override {
    // free(nullptr) is a no-op, so no explicit null check is required
    free(batch.token);
    // cleared before handing the batch over - presumably so llama_batch_free
    // does not release the same pointer again
    batch.token = nullptr;

    llama_batch_free(batch);
}
|
||||
|
||||
// Sanity check for a sequence's prompt: warns when the draft context's maximum
// position for `seq_id` is below prompt.size() - 1. Empty prompts are ignored.
// No state is modified.
void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
    const int32_t n_prompt = (int32_t) prompt.size();
    if (n_prompt <= 0) {
        return;
    }

    auto * mem_dft = llama_get_memory(this->params.ctx_dft);

    const llama_pos pos_max = llama_memory_seq_pos_max(mem_dft, seq_id);
    if (pos_max >= n_prompt - 1) {
        return;
    }

    LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d — "
            "process() hook may not have run on every prefill ubatch "
            "(need_embd / logits=1 on every prompt position?). "
            "Drafts may degrade.\n",
            __func__, (int) pos_max, n_prompt - 1);
}
|
||||
|
||||
bool process(const llama_batch & batch_in) override {
|
||||
if (batch_in.n_tokens <= 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// TODO: how to make it work with vision tokens?
|
||||
if (batch_in.token == nullptr || batch_in.embd != nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const int32_t n_tokens = batch_in.n_tokens;
|
||||
|
||||
// remember the frist and last batch index for each sequence
|
||||
std::fill(i_batch_beg.begin(), i_batch_beg.end(), -1);
|
||||
std::fill(i_batch_end.begin(), i_batch_end.end(), -1);
|
||||
|
||||
for (int k = 0; k < n_tokens; ++k) {
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
GGML_ASSERT(batch_in.n_seq_id[k] == 1);
|
||||
|
||||
if (batch_in.seq_id[k][0] == seq_id) {
|
||||
i_batch_end[seq_id] = k;
|
||||
if (i_batch_beg[seq_id] < 0) {
|
||||
i_batch_beg[seq_id] = k;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto * ctx_tgt = this->params.ctx_tgt;
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
|
||||
const size_t row_bytes = (size_t) n_embd * sizeof(float);
|
||||
|
||||
common_batch_clear(batch);
|
||||
|
||||
for (int k = 0; k < n_tokens; ++k) {
|
||||
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
|
||||
}
|
||||
|
||||
// shift the tgt embeddings to the right by one position
|
||||
// assumes that the tokens in the batch are sequential for each sequence
|
||||
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
|
||||
// ^--- this is a problem
|
||||
// TODO:this is generally true, but would be nice to assert it
|
||||
{
|
||||
const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt);
|
||||
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
|
||||
|
||||
//{
|
||||
// // string with seq_ids in the batch
|
||||
// std::stringstream ss;
|
||||
// for (int i = 0; i < n_tokens; ++i) {
|
||||
// ss << batch_in.seq_id[i][0] << ",";
|
||||
// }
|
||||
// LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str());
|
||||
//}
|
||||
}
|
||||
|
||||
// fill the pending embeddings from a previous run
|
||||
auto set_h = [&](int idx, const float * h_row) {
|
||||
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
|
||||
};
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (i_batch_beg[seq_id] < 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
|
||||
}
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
|
||||
return false;
|
||||
}
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (i_batch_end[seq_id] < 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int32_t n_rows = i_batch_end[seq_id] - i_batch_beg[seq_id] + 1;
|
||||
verify_h_rows[seq_id] = n_rows;
|
||||
verify_h[seq_id].resize((size_t) n_rows * n_embd);
|
||||
|
||||
for (int32_t i = 0; i < n_rows; ++i) {
|
||||
const float * h = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_beg[seq_id] + i);
|
||||
std::memcpy(verify_h[seq_id].data() + (size_t) i * n_embd, h, row_bytes);
|
||||
}
|
||||
|
||||
std::memcpy(pending_h[seq_id].data(),
|
||||
verify_h[seq_id].data() + (size_t) (n_rows - 1) * n_embd, row_bytes);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Drafts tokens autoregressively on ctx_dft for every sequence with dp.drafting set.
//
// The first decode feeds (dp.id_last, pending_h) at position dp.n_past for each
// drafting sequence. Each following step samples the top candidate, appends it to
// *dp.result and feeds it back together with the draft context's own pre-norm
// hidden row, until params.n_max tokens were drafted or a decode fails.
// Drafts shorter than params.n_min are discarded. The final draft length of each
// drafting sequence is recorded in last_n_drafted.
void draft(common_speculative_draft_params_vec & dparams) override {
    auto & ctx_dft = params.ctx_dft;

    common_batch_clear(batch);

    // keep track of which sequences are still drafting
    int n_drafting = 0;
    std::vector<bool> drafting(n_seq);

    const float * h_row = nullptr;
    const size_t row_bytes = (size_t) n_embd * sizeof(float);

    for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
        auto & dp = dparams[seq_id];

        if (!dp.drafting) {
            continue;
        }

        n_drafting++;
        drafting[seq_id] = true;
        common_sampler_reset(smpls[seq_id].get());

        common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);

        h_row = pending_h[seq_id].data();
        std::memcpy(batch.embd + (size_t) n_embd * (batch.n_tokens - 1), h_row, row_bytes);
    }

    // nothing to draft: do not submit an empty batch to llama_decode
    if (n_drafting == 0) {
        return;
    }

    int ret = llama_decode(ctx_dft, batch);
    if (ret != 0) {
        LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
        return;
    }

    // index of the current draft step (0 = first drafted token)
    int i = 0;

    while (n_drafting > 0) {
        // output row of the previous decode; advances once per sequence that was part of it
        int i_batch = 0;

        common_batch_clear(batch);

        for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
            if (!drafting[seq_id]) {
                continue;
            }

            auto * smpl = smpls[seq_id].get();

            common_sampler_sample(smpl, ctx_dft, i_batch, true);
            h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, i_batch);
            ++i_batch;

            const auto * cur_p = common_sampler_get_candidates(smpl, true);

            for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
                LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
                        seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
                        common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
            }

            // add drafted token for each sequence
            const llama_token id = cur_p->data[0].id;

            common_sampler_accept(smpl, id, true);

            auto & dp = dparams.at(seq_id);
            auto & result = *dp.result;

            result.push_back(id);

            // draft budget reached: this sequence takes no part in further decodes
            if (params.n_max <= (int) result.size()) {
                drafting[seq_id] = false;
                n_drafting--;
                continue;
            }

            common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
            std::memcpy(batch.embd + (size_t) n_embd * (batch.n_tokens - 1), h_row, row_bytes);
        }

        if (batch.n_tokens == 0) {
            break;
        }

        // evaluate the drafted tokens on the draft model
        ret = llama_decode(ctx_dft, batch);
        if (ret != 0) {
            LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
            break;
        }

        ++i;
    }

    for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
        auto & dp = dparams[seq_id];
        if (!dp.drafting) {
            continue;
        }

        // too short to be worth verifying
        if (dp.result->size() < (size_t) params.n_min) {
            dp.result->clear();
        }

        last_n_drafted[seq_id] = (uint16_t) dp.result->size();
    }
}
|
||||
|
||||
// Select the hidden row that the next draft for `seq_id` starts from.
//
// Copies row `min(n_accepted, rows - 1)` of `verify_h[seq_id]` into
// `pending_h[seq_id]`, where `rows` is `verify_h_rows[seq_id]`.
// Does nothing for an out-of-range `seq_id` or when no rows are recorded.
void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
    const bool seq_in_range = seq_id >= 0 && seq_id < (llama_seq_id) n_seq;
    if (!seq_in_range) {
        return;
    }

    const int32_t rows = verify_h_rows[seq_id];
    if (rows <= 0) {
        return;
    }

    // clamp to the last stored row
    int32_t src_row = n_accepted;
    if (src_row > rows - 1) {
        src_row = rows - 1;
    }

    const float * src = verify_h[seq_id].data() + (size_t) src_row * n_embd;

    std::memcpy(pending_h[seq_id].data(), src, (size_t) n_embd * sizeof(float));
}
|
||||
|
||||
// This implementation always asks for embeddings to be made available.
bool need_embd() const override {
    constexpr bool wants_embd = true;
    return wants_embd;
}
|
||||
};
|
||||
|
||||
// state of self-speculation (simple implementation, not ngram-map)
|
||||
@@ -403,6 +734,10 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
|
||||
// noop
|
||||
}
|
||||
|
||||
// This implementation never asks for embeddings.
bool need_embd() const override {
    constexpr bool wants_embd = false;
    return wants_embd;
}
|
||||
};
|
||||
|
||||
struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
|
||||
@@ -451,6 +786,10 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
|
||||
|
||||
common_ngram_map_accept(config[seq_id], n_accepted);
|
||||
}
|
||||
|
||||
// This implementation never asks for embeddings.
bool need_embd() const override {
    constexpr bool wants_embd = false;
    return wants_embd;
}
|
||||
};
|
||||
|
||||
struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
||||
@@ -619,6 +958,10 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This implementation never asks for embeddings.
bool need_embd() const override {
    constexpr bool wants_embd = false;
    return wants_embd;
}
|
||||
};
|
||||
|
||||
struct common_speculative_impl_ngram_cache : public common_speculative_impl {
|
||||
@@ -752,6 +1095,10 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
|
||||
// noop
|
||||
}
|
||||
|
||||
// This implementation never asks for embeddings.
bool need_embd() const override {
    constexpr bool wants_embd = false;
    return wants_embd;
}
|
||||
};
|
||||
|
||||
struct common_speculative {
|
||||
@@ -820,6 +1167,7 @@ std::string common_speculative_type_to_str(common_speculative_type type) {
|
||||
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: return "draft-mtp";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v";
|
||||
@@ -875,8 +1223,8 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
bool has_draft_model_path = !params.draft.mparams.path.empty();
|
||||
|
||||
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
|
||||
// bool has_mtp = false; // TODO: add MTP here
|
||||
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
|
||||
bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
|
||||
|
||||
bool has_ngram_cache = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_CACHE));
|
||||
bool has_ngram_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE));
|
||||
@@ -885,7 +1233,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD));
|
||||
|
||||
// when adding a new type - update here the logic above
|
||||
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 8);
|
||||
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 9);
|
||||
|
||||
// this list here defines the priority of the speculators
|
||||
// the one with highest priority are listed first
|
||||
@@ -911,7 +1259,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
LOG_WRN("%s: draft model is not specified - cannot use 'draft' type\n", __func__);
|
||||
has_draft_simple = false;
|
||||
}
|
||||
} else if (has_draft_model_path) {
|
||||
} else if (has_draft_model_path && !has_mtp && !has_draft_eagle3) {
|
||||
LOG_WRN("%s: draft model is specified but 'draft' speculative type is not explicitly enabled - enabling it\n", __func__);
|
||||
has_draft_simple = true;
|
||||
}
|
||||
@@ -919,10 +1267,12 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
if (has_draft_simple) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, params));
|
||||
}
|
||||
// TODO: add MTP here
|
||||
if (has_draft_eagle3) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
|
||||
}
|
||||
if (has_mtp) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<common_speculative_impl>> impls = {};
|
||||
@@ -940,6 +1290,10 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
impls.push_back(std::make_unique<common_speculative_impl_draft_eagle3>(config.params, n_seq));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: {
|
||||
impls.push_back(std::make_unique<common_speculative_state_draft_mtp>(config.params, n_seq));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
|
||||
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
|
||||
|
||||
@@ -1040,6 +1394,20 @@ bool common_speculative_process(common_speculative * spec, const llama_batch & b
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns true when at least one implementation held by `spec` reports
// need_embd(); returns false for a null `spec` or when none of them does.
bool common_speculative_need_embd(common_speculative * spec) {
    bool wanted = false;

    if (spec != nullptr) {
        for (const auto & impl : spec->impls) {
            if (impl->need_embd()) {
                // one requester is enough, no need to look further
                wanted = true;
                break;
            }
        }
    }

    return wanted;
}
|
||||
|
||||
void common_speculative_draft(common_speculative * spec) {
|
||||
if (spec == nullptr) {
|
||||
return;
|
||||
@@ -1122,14 +1490,15 @@ void common_speculative_draft(common_speculative * spec) {
|
||||
}
|
||||
|
||||
void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted) {
|
||||
if (n_accepted == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
common_speculative_impl * impl = spec->impl_last[seq_id];
|
||||
|
||||
GGML_ASSERT(impl);
|
||||
|
||||
// TODO: currently only the implementation that generated the draft is used to accept it
|
||||
// however, some implementations (such as MTP) need to also "see" the accepted tokens
|
||||
// extend `common_speculative_impl::accept()` with an extra argument `bool is_other` to
|
||||
// inform the implementation if the accepted tokens are from another implementation and
|
||||
// pass the accepted tokens to all remaining implementations using `is_other == true`
|
||||
{
|
||||
common_time_meas tm(impl->t_accept_us, !impl->gen_perf);
|
||||
if (n_accepted > 0) {
|
||||
|
||||
@@ -53,6 +53,9 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co
|
||||
// process the batch and update the internal state of the speculative context
|
||||
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
|
||||
|
||||
// true if any implementation requires target embeddings to be extracted
|
||||
bool common_speculative_need_embd(common_speculative * spec);
|
||||
|
||||
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
|
||||
void common_speculative_draft(common_speculative * spec);
|
||||
|
||||
|
||||
@@ -91,6 +91,7 @@ class ModelBase:
|
||||
gguf_writer: gguf.GGUFWriter
|
||||
model_name: str | None
|
||||
metadata_override: Path | None
|
||||
metadata: gguf.Metadata
|
||||
dir_model_card: Path
|
||||
remote_hf_model_id: str | None
|
||||
|
||||
@@ -106,6 +107,11 @@ class ModelBase:
|
||||
disable_mistral_community_chat_template: bool = False
|
||||
sentence_transformers_dense_modules: bool = False
|
||||
|
||||
# MTP (multi-token prediction) export modes; set by main() before instantiation.
|
||||
# Architectures opt in by overriding the handling (see _Qwen35MtpMixin).
|
||||
mtp_only: bool = False
|
||||
no_mtp: bool = False
|
||||
|
||||
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
|
||||
use_temp_file: bool = False, eager: bool = False,
|
||||
metadata_override: Path | None = None, model_name: str | None = None,
|
||||
|
||||
+86
-3
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, Iterable, TYPE_CHECKING
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Iterable, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
@@ -534,11 +535,93 @@ class _Qwen35MRopeMixin:
|
||||
self.gguf_writer.add_rope_dimension_sections(self._QWEN35_DEFAULT_MROPE_SECTION)
|
||||
|
||||
|
||||
class _Qwen35MtpMixin:
    """MTP (multi-token prediction) handling shared by the Qwen3.5/3.6 text models.

    The HF config reports the number of MTP blocks in `mtp_num_hidden_layers`
    and stores their tensors under the `mtp.*` prefix. This mixin:

    - grows `block_count` by the MTP block count (unless `no_mtp` is set),
    - writes the nextn-predict-layers metadata key,
    - renames `mtp.*` tensors to layer-indexed `model.layers.*` names placed
      after the regular layers, so the regular tensor map can resolve them.

    `no_mtp` drops every `mtp.*` tensor. `mtp_only` keeps the `mtp.*` tensors
    plus the embedding / final norm / lm_head tensors listed in
    `filter_tensors`.
    """

    hparams: dict[str, Any]
    model_arch: gguf.MODEL_ARCH
    gguf_writer: gguf.GGUFWriter
    block_count: int
    tensor_map: gguf.TensorNameMap
    no_mtp: bool
    mtp_only: bool

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # MTP blocks are counted as extra layers appended after the regular ones
        n_mtp = 0 if self.no_mtp else self.hparams.get("mtp_num_hidden_layers", 0)
        self.block_count = self.hparams["num_hidden_layers"] + n_mtp
        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

    @classmethod
    def filter_tensors(cls, item):
        """Decide whether the `(name, ...)` pair `item` is exported.

        Returns `None` to drop the tensor. `mtp.*` tensors are returned as-is
        (or dropped under `no_mtp`); everything else is delegated to the next
        class in the MRO, after the `mtp_only` allow-list has been applied.
        """
        tensor_name, _ = item

        if tensor_name.startswith("mtp."):
            return None if cls.no_mtp else item

        if cls.mtp_only:
            shared_names = (
                "model.embed_tokens.weight", "model.norm.weight", "lm_head.weight",
                "embed_tokens.weight", "norm.weight",
            )
            # the allow-list is matched with any "language_model." segment removed
            if tensor_name.replace("language_model.", "") not in shared_names:
                return None

        return super().filter_tensors(item)  # ty: ignore[unresolved-attribute]

    def set_gguf_parameters(self):
        """Write the parent's parameters, then the MTP layer count when MTP is exported."""
        super().set_gguf_parameters()  # ty: ignore[unresolved-attribute]

        if not self.no_mtp:
            n_mtp = self.hparams.get("mtp_num_hidden_layers", 0)
            if n_mtp > 0:
                self.gguf_writer.add_nextn_predict_layers(n_mtp)

    def prepare_metadata(self, vocab_only: bool):
        """Run the parent's metadata preparation; in `mtp_only` mode with a
        directory as output path, replace the output file name with an
        `mtp-`-prefixed default name."""
        # sampled before delegating to the parent
        # NOTE(review): presumably the parent rewrites `fname_out` - confirm
        out_is_dir = self.fname_out.is_dir()
        super().prepare_metadata(vocab_only=vocab_only)  # ty: ignore[unresolved-attribute]

        if self.mtp_only and out_is_dir:
            output_type: str = self.ftype.name.partition("_")[2]  # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
            fname_default: str = gguf.naming_convention(
                self.metadata.name, self.metadata.basename, self.metadata.finetune,  # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
                self.metadata.version, size_label=None, output_type=output_type, model_type=None)  # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
            self.fname_out = self.fname_out.parent / f"mtp-{fname_default}.gguf"

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Rename `mtp.*` tensors to `model.layers.*` names and hand them to the
        next class in the MRO; any other tensor is forwarded untouched."""
        if not name.startswith("mtp."):
            yield from super().modify_tensors(data_torch, name, bid)  # ty: ignore[unresolved-attribute]
            return

        n_layer = self.hparams["num_hidden_layers"]

        if "layers." in name:
            # per-layer MTP tensor: shift its index past the regular layers
            assert bid is not None
            name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + n_layer}")
            # NOTE(review): `bid` is forwarded without the n_layer offset - confirm this is intended
            yield from super().modify_tensors(data_torch, name, bid)  # ty: ignore[unresolved-attribute]
            return

        # MTP tensor without a layer index: emitted once per MTP block
        remapper = {
            "mtp.fc": "model.layers.{bid}.eh_proj",
            "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
            "mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm",
            "mtp.norm": "model.layers.{bid}.shared_head.norm",
        }
        # e.g. "mtp.fc.weight" -> stem "mtp.fc", suffix ".weight"
        parts = Path(name)
        tmpl = remapper[parts.stem] + parts.suffix
        for b in range(n_layer, self.block_count):
            yield from super().modify_tensors(data_torch, tmpl.format(bid=b), b)  # ty: ignore[unresolved-attribute]
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM")
|
||||
class Qwen3_5TextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase):
|
||||
class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN35
|
||||
|
||||
|
||||
@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
|
||||
class Qwen3_5MoeTextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase):
|
||||
class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN35MOE
|
||||
|
||||
@@ -117,6 +117,14 @@ def parse_args() -> argparse.Namespace:
|
||||
"--mmproj", action="store_true",
|
||||
help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.",
|
||||
)
|
||||
parser.add_argument(
    "--mtp", action="store_true",
    # the MTP-only export is named with an "mtp-" prefix (not a suffix), like "mmproj-" for --mmproj
    help="(Experimental) Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. A prefix 'mtp-' will be added to the default output file name.",
)
|
||||
parser.add_argument(
|
||||
"--no-mtp", action="store_true",
|
||||
help="(Experimental) Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. Note: the split form duplicates embeddings, so the bundled default is more space-efficient overall.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mistral-format", action="store_true",
|
||||
help="Whether the model is stored following the Mistral format.",
|
||||
@@ -233,6 +241,20 @@ def main() -> None:
|
||||
from conversion.mistral import MistralModel
|
||||
model_class = MistralModel
|
||||
|
||||
if args.mtp and args.no_mtp:
|
||||
logger.error("--mtp and --no-mtp are mutually exclusive")
|
||||
sys.exit(1)
|
||||
|
||||
if args.mtp or args.no_mtp:
|
||||
from conversion.qwen import _Qwen35MtpMixin
|
||||
if not issubclass(model_class, _Qwen35MtpMixin):
|
||||
logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 text variants today")
|
||||
sys.exit(1)
|
||||
if args.no_mtp:
|
||||
model_class.no_mtp = True
|
||||
if args.mtp:
|
||||
model_class.mtp_only = True
|
||||
|
||||
model_instance = model_class(dir_model, output_type, fname_out,
|
||||
is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
|
||||
eager=args.no_lazy,
|
||||
|
||||
@@ -2541,6 +2541,11 @@ extern "C" {
|
||||
|
||||
// TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST]
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306
|
||||
//
|
||||
// state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs):
|
||||
// K == 1: output carries the final state only.
|
||||
// K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K)
|
||||
// per-token snapshots into the trailing slots
|
||||
GGML_API struct ggml_tensor * ggml_gated_delta_net(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * q,
|
||||
|
||||
@@ -753,7 +753,9 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co
|
||||
GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1);
|
||||
GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1);
|
||||
GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1);
|
||||
GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2);
|
||||
// state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0,
|
||||
// so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2).
|
||||
GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0);
|
||||
return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1};
|
||||
};
|
||||
|
||||
@@ -2140,4 +2142,3 @@ ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, siz
|
||||
const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context;
|
||||
return backend_ctx->backend_configs[index].backend;
|
||||
}
|
||||
|
||||
|
||||
@@ -2943,7 +2943,9 @@ struct ggml_cplan ggml_graph_plan(
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
{
|
||||
const int64_t S_v = node->src[2]->ne[0];
|
||||
cur = S_v * sizeof(float) * n_tasks;
|
||||
const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs)
|
||||
const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0);
|
||||
cur = per_thread * sizeof(float) * n_tasks;
|
||||
} break;
|
||||
case GGML_OP_COUNT:
|
||||
{
|
||||
|
||||
@@ -10513,19 +10513,30 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
|
||||
|
||||
const bool kda = (neg0 == S_v);
|
||||
|
||||
// scratch layout per thread: [delta(S_v)]
|
||||
const int64_t scratch_per_thread = S_v;
|
||||
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
|
||||
const int64_t K = src_state->ne[1];
|
||||
GGML_ASSERT(K >= 1);
|
||||
// per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride)
|
||||
const int64_t state_seq_stride = src_state->nb[2] / sizeof(float);
|
||||
|
||||
const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0);
|
||||
const int ith = params->ith;
|
||||
|
||||
float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
|
||||
float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32;
|
||||
float * state_work = K > 1 ? (delta + S_v) : nullptr;
|
||||
|
||||
// output layout: [attn_scores | new_states]
|
||||
// attn_scores: S_v * H * n_tokens * n_seqs floats
|
||||
// new_states: S_v * S_v * H * n_seqs floats
|
||||
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
|
||||
// attn_scores: S_v * H * n_tokens * n_seqs floats
|
||||
// new_states: S_v * S_v * H * n_seqs * K floats (K snapshot slots; last min(n_tokens, K))
|
||||
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
|
||||
const int64_t state_size_per_snap = S_v * S_v * H * n_seqs;
|
||||
float * attn_out_base = (float *)dst->data;
|
||||
float * state_out_base = (float *)dst->data + attn_score_elems;
|
||||
|
||||
// snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last
|
||||
// n_tokens slots are written; earlier slots are left untouched (caller-owned).
|
||||
const int64_t shift = n_tokens - K;
|
||||
|
||||
const float * state_in_base = (const float *)src_state->data;
|
||||
|
||||
//const int64_t rq1 = nev1 / neq1;
|
||||
@@ -10545,10 +10556,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
|
||||
const int64_t iq3 = iv3 / rq3;
|
||||
const int64_t ik3 = iv3 / rk3;
|
||||
|
||||
float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;
|
||||
// For K=1, write directly to the single output slot to avoid an extra memcpy at the end.
|
||||
// For K>1, work in scratch and copy out per-token when the slot is in range.
|
||||
float * s_out = (K > 1)
|
||||
? state_work
|
||||
: state_out_base + (iv3 * H + iv1) * S_v * S_v;
|
||||
|
||||
// copy input state into output buffer and operate in-place
|
||||
const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
|
||||
// copy input state into the working buffer and operate in-place
|
||||
// state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride.
|
||||
const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v;
|
||||
memcpy(s_out, s_in, S_v * S_v * sizeof(float));
|
||||
|
||||
// attn output pointer for first token of this (head, seq)
|
||||
@@ -10598,6 +10614,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
|
||||
}
|
||||
|
||||
attn_data += S_v * H; // advance to next token
|
||||
|
||||
if (K > 1) {
|
||||
const int64_t target_slot = t - shift;
|
||||
if (target_slot >= 0 && target_slot < K) {
|
||||
float * curr_state_o = state_out_base + target_slot * state_size_per_snap +
|
||||
(iv3 * H + iv1) * S_v * S_v;
|
||||
memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "gated_delta_net.cuh"
|
||||
|
||||
template <int S_v, bool KDA>
|
||||
template <int S_v, bool KDA, bool keep_rs_t>
|
||||
__global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
|
||||
gated_delta_net_cuda(const float * q,
|
||||
const float * k,
|
||||
@@ -23,7 +23,8 @@ gated_delta_net_cuda(const float * q,
|
||||
int64_t sb3,
|
||||
const uint3 neqk1_magic,
|
||||
const uint3 rq3_magic,
|
||||
float scale) {
|
||||
float scale,
|
||||
int K) {
|
||||
const uint32_t h_idx = blockIdx.x;
|
||||
const uint32_t sequence = blockIdx.y;
|
||||
// each warp owns one column, using warp-level primitives to reduce across rows
|
||||
@@ -37,9 +38,13 @@ gated_delta_net_cuda(const float * q,
|
||||
float * attn_data = dst;
|
||||
float * state = dst + attn_score_elems;
|
||||
|
||||
const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
|
||||
state += state_offset;
|
||||
curr_state += state_offset + col * S_v;
|
||||
// input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v.
|
||||
// output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
|
||||
const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v;
|
||||
const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v;
|
||||
const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
|
||||
state += state_out_offset;
|
||||
curr_state += state_in_offset + col * S_v;
|
||||
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
|
||||
@@ -54,6 +59,10 @@ gated_delta_net_cuda(const float * q,
|
||||
s_shard[r] = curr_state[i];
|
||||
}
|
||||
|
||||
// slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots
|
||||
// are written; earlier slots are left untouched (caller-owned).
|
||||
const int shift = (int) n_tokens - K;
|
||||
|
||||
for (int t = 0; t < n_tokens; t++) {
|
||||
const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
|
||||
const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
|
||||
@@ -135,17 +144,30 @@ gated_delta_net_cuda(const float * q,
|
||||
}
|
||||
|
||||
attn_data += S_v * H;
|
||||
|
||||
if constexpr (keep_rs_t) {
|
||||
const int target_slot = t - shift;
|
||||
if (target_slot >= 0 && target_slot < K) {
|
||||
float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
curr_state[col * S_v + i] = s_shard[r];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write state back to global memory (transposed layout)
|
||||
if constexpr (!keep_rs_t) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
state[col * S_v + i] = s_shard[r];
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
state[col * S_v + i] = s_shard[r];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool KDA>
|
||||
template <bool KDA, bool keep_rs_t>
|
||||
static void launch_gated_delta_net(
|
||||
const float * q_d, const float * k_d, const float * v_d,
|
||||
const float * g_d, const float * b_d, const float * s_d,
|
||||
@@ -155,7 +177,7 @@ static void launch_gated_delta_net(
|
||||
int64_t sv1, int64_t sv2, int64_t sv3,
|
||||
int64_t sb1, int64_t sb2, int64_t sb3,
|
||||
int64_t neqk1, int64_t rq3,
|
||||
float scale, cudaStream_t stream) {
|
||||
float scale, int K, cudaStream_t stream) {
|
||||
//TODO: Add chunked kernel for even faster pre-fill
|
||||
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
||||
const int num_warps = 4;
|
||||
@@ -169,29 +191,29 @@ static void launch_gated_delta_net(
|
||||
|
||||
switch (S_v) {
|
||||
case 16:
|
||||
gated_delta_net_cuda<16, KDA><<<grid_dims, block_dims, 0, stream>>>(
|
||||
gated_delta_net_cuda<16, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
break;
|
||||
case 32:
|
||||
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
|
||||
gated_delta_net_cuda<32, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
break;
|
||||
case 64: {
|
||||
gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
|
||||
gated_delta_net_cuda<64, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
break;
|
||||
}
|
||||
case 128: {
|
||||
gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
|
||||
gated_delta_net_cuda<128, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -261,13 +283,29 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
|
||||
const int K = (int) src_state->ne[1];
|
||||
const bool keep_rs = K > 1;
|
||||
|
||||
if (kda) {
|
||||
launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, stream);
|
||||
if (keep_rs) {
|
||||
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
} else {
|
||||
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
}
|
||||
} else {
|
||||
launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, stream);
|
||||
if (keep_rs) {
|
||||
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
} else {
|
||||
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -590,6 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(
|
||||
const int ne20 = op->src[2]->ne[0]; // S_v
|
||||
const int ne21 = op->src[2]->ne[1]; // H
|
||||
const int ne30 = op->src[3]->ne[0]; // G
|
||||
// state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
|
||||
const int K = op->src[5]->ne[1];
|
||||
|
||||
const int nsg = op->src[2]->ne[0]/32;
|
||||
|
||||
@@ -598,7 +600,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(
|
||||
GGML_ASSERT(ne20 % 32 == 0);
|
||||
|
||||
snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg);
|
||||
snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30);
|
||||
snprintf(name, 256, "%s_ne20=%d_ne30=%d_K=%d", base, ne20, ne30, K);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
@@ -606,6 +608,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(
|
||||
|
||||
ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);
|
||||
ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);
|
||||
ggml_metal_cv_set_int16(cv, K, FC_GATED_DELTA_NET + 2);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
|
||||
@@ -2531,6 +2531,7 @@ kernel void kernel_rwkv_wkv7_f32(
|
||||
|
||||
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
|
||||
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
|
||||
constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
|
||||
|
||||
#if 1
|
||||
template<short NSG>
|
||||
@@ -2548,21 +2549,24 @@ kernel void kernel_gated_delta_net_impl(
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
#define S_v FC_gated_delta_net_ne20
|
||||
#define G FC_gated_delta_net_ne30
|
||||
#define K FC_gated_delta_net_K
|
||||
|
||||
const uint tx = tpitg.x;
|
||||
const uint ty = tpitg.y;
|
||||
|
||||
const uint i23 = tgpig.z; // B
|
||||
const uint i21 = tgpig.y; // H
|
||||
const uint i20 = tgpig.x*NSG + ty;
|
||||
const uint i23 = tgpig.z; // B (n_seqs)
|
||||
const uint i21 = tgpig.y; // H (head)
|
||||
const uint i20 = tgpig.x*NSG + ty; // row within S_v
|
||||
|
||||
const uint i01 = i21 % args.ne01;
|
||||
const uint i11 = i21 % args.ne11;
|
||||
|
||||
const float scale = 1.0f / sqrt((float)S_v);
|
||||
|
||||
// input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0.
|
||||
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
|
||||
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
device const float * s_ptr = (device const float *) (s) + state_in_base;
|
||||
|
||||
float ls[NSG];
|
||||
|
||||
@@ -2580,6 +2584,17 @@ kernel void kernel_gated_delta_net_impl(
|
||||
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
|
||||
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
|
||||
|
||||
// snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last
|
||||
// n_tokens slots are written; earlier slots are left untouched (caller-owned).
|
||||
const int shift = (int)args.ne22 - (int)K;
|
||||
|
||||
// output state base offset: after attention scores
|
||||
const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
|
||||
// output state per-slot size: S_v * S_v * H * n_seqs
|
||||
const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23;
|
||||
// per-(seq,head) offset within a slot
|
||||
const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
|
||||
for (short t = 0; t < args.ne22; t++) {
|
||||
float s_k = 0.0f;
|
||||
|
||||
@@ -2627,17 +2642,30 @@ kernel void kernel_gated_delta_net_impl(
|
||||
|
||||
b_ptr += args.ne21;
|
||||
g_ptr += args.ne21*G;
|
||||
|
||||
if (K > 1u) {
|
||||
const int target_slot = (int)t - shift;
|
||||
if (target_slot >= 0 && target_slot < (int)K) {
|
||||
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
dst_state[is] = ls[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
dst_state[is] = ls[j];
|
||||
if (K == 1u) {
|
||||
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
dst_state[is] = ls[j];
|
||||
}
|
||||
}
|
||||
|
||||
#undef S_v
|
||||
#undef G
|
||||
#undef K
|
||||
}
|
||||
|
||||
typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
|
||||
|
||||
@@ -1506,6 +1506,7 @@ struct vk_op_gated_delta_net_push_constants {
|
||||
uint32_t sb1, sb2, sb3;
|
||||
uint32_t neq1, rq3;
|
||||
float scale;
|
||||
uint32_t K;
|
||||
};
|
||||
|
||||
struct vk_op_ssm_scan_push_constants {
|
||||
@@ -10767,6 +10768,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
|
||||
const ggml_tensor * src_q = dst->src[0];
|
||||
const ggml_tensor * src_v = dst->src[2];
|
||||
const ggml_tensor * src_beta = dst->src[4];
|
||||
const ggml_tensor * src_state = dst->src[5];
|
||||
|
||||
GGML_ASSERT(dst->buffer != nullptr);
|
||||
|
||||
@@ -10775,6 +10777,9 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
|
||||
const uint32_t n_tokens = (uint32_t)src_v->ne[2];
|
||||
const uint32_t n_seqs = (uint32_t)src_v->ne[3];
|
||||
|
||||
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
|
||||
const uint32_t K = (uint32_t)src_state->ne[1];
|
||||
|
||||
const uint32_t s_off = S_v * H * n_tokens * n_seqs;
|
||||
|
||||
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
|
||||
@@ -10808,7 +10813,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
|
||||
sv1, sv2, sv3,
|
||||
sb1, sb2, sb3,
|
||||
neq1, rq3,
|
||||
scale
|
||||
scale,
|
||||
K
|
||||
};
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
|
||||
@@ -31,6 +31,7 @@ layout(push_constant) uniform Parameters {
|
||||
uint sb1, sb2, sb3;
|
||||
uint neq1, rq3;
|
||||
float scale;
|
||||
uint K;
|
||||
};
|
||||
|
||||
layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; };
|
||||
@@ -101,13 +102,21 @@ void main() {
|
||||
const uint iq3 = seq_id / rq3;
|
||||
|
||||
const uint state_size = S_V * S_V;
|
||||
const uint state_base = (seq_id * H + head_id) * state_size;
|
||||
// input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0.
|
||||
const uint state_in_base = (seq_id * K * H + head_id) * state_size;
|
||||
// output state layout per slot: same per-(seq,head) offset as the single-slot case.
|
||||
const uint state_out_base = (seq_id * H + head_id) * state_size;
|
||||
const uint state_size_per_snap = state_size * H * n_seqs;
|
||||
|
||||
FLOAT_TYPE s_shard[ROWS_PER_LANE];
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]);
|
||||
s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]);
|
||||
}
|
||||
|
||||
// snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last
|
||||
// n_tokens slots are written; earlier slots are left untouched (caller-owned).
|
||||
const int shift = int(n_tokens) - int(K);
|
||||
|
||||
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
|
||||
|
||||
for (uint t = 0; t < n_tokens; t++) {
|
||||
@@ -161,9 +170,21 @@ void main() {
|
||||
}
|
||||
|
||||
attn_off += S_V * H;
|
||||
|
||||
if (K > 1u) {
|
||||
const int target_slot = int(t) - shift;
|
||||
if (target_slot >= 0 && target_slot < int(K)) {
|
||||
const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base;
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
data_dst[slot_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
|
||||
if (K == 1u) {
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
data_dst[s_off + state_out_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+7
-5
@@ -6210,11 +6210,13 @@ struct ggml_tensor * ggml_gated_delta_net(
|
||||
GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
|
||||
GGML_ASSERT(beta->ne[0] == 1);
|
||||
|
||||
GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);
|
||||
|
||||
// concat output and new_state into a single tensor
|
||||
// output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs
|
||||
const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 };
|
||||
// state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count.
|
||||
GGML_ASSERT(state->ne[0] == S_v * S_v * H);
|
||||
GGML_ASSERT(state->ne[2] == n_seqs);
|
||||
GGML_ASSERT(state->ne[3] == 1);
|
||||
const int64_t K = state->ne[1];
|
||||
const int64_t state_rows = K * S_v * n_seqs;
|
||||
const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
result->op = GGML_OP_GATED_DELTA_NET;
|
||||
|
||||
@@ -2114,7 +2114,14 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.SSM_NORM,
|
||||
MODEL_TENSOR.SSM_BETA,
|
||||
MODEL_TENSOR.SSM_ALPHA,
|
||||
MODEL_TENSOR.SSM_OUT
|
||||
MODEL_TENSOR.SSM_OUT,
|
||||
# NextN/MTP tensors - preserved but unused
|
||||
MODEL_TENSOR.NEXTN_EH_PROJ,
|
||||
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
|
||||
MODEL_TENSOR.NEXTN_ENORM,
|
||||
MODEL_TENSOR.NEXTN_HNORM,
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
|
||||
],
|
||||
MODEL_ARCH.QWEN35MOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
@@ -2145,7 +2152,14 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.SSM_NORM,
|
||||
MODEL_TENSOR.SSM_BETA,
|
||||
MODEL_TENSOR.SSM_ALPHA,
|
||||
MODEL_TENSOR.SSM_OUT
|
||||
MODEL_TENSOR.SSM_OUT,
|
||||
# NextN/MTP tensors - preserved but unused
|
||||
MODEL_TENSOR.NEXTN_EH_PROJ,
|
||||
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
|
||||
MODEL_TENSOR.NEXTN_ENORM,
|
||||
MODEL_TENSOR.NEXTN_HNORM,
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
|
||||
],
|
||||
MODEL_ARCH.PLAMO: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
|
||||
@@ -198,6 +198,11 @@ extern "C" {
|
||||
LLAMA_SPLIT_MODE_TENSOR = 3,
|
||||
};
|
||||
|
||||
// Kind of context to create; stored in the context parameters as `ctx_type`.
// The explicit values are part of the public C API.
enum llama_context_type {
|
||||
LLAMA_CONTEXT_TYPE_DEFAULT = 0, // regular context (the default in llama_context_default_params)
|
||||
LLAMA_CONTEXT_TYPE_MTP = 1, // MTP context; llama_init_from_model returns nullptr if the model has no NextN/MTP layers
|
||||
};
|
||||
|
||||
// TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979)
|
||||
typedef struct llama_token_data {
|
||||
llama_token id; // token id
|
||||
@@ -333,9 +338,11 @@ extern "C" {
|
||||
uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
|
||||
uint32_t n_ubatch; // physical maximum batch size
|
||||
uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
|
||||
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL]
|
||||
int32_t n_threads; // number of threads to use for generation
|
||||
int32_t n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
enum llama_context_type ctx_type; // set the context type (e.g. MTP)
|
||||
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
|
||||
enum llama_attention_type attention_type; // attention type to use for embeddings
|
||||
@@ -530,6 +537,7 @@ extern "C" {
|
||||
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx);
|
||||
|
||||
DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
|
||||
DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead");
|
||||
|
||||
+19
-8
@@ -757,14 +757,15 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
|
||||
// These tensors only exist in the last layer(s) and are treated as output tensors
|
||||
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
|
||||
// last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so
|
||||
// the model loader doesn't fault on the block index.
|
||||
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
// Nemotron 3 Super
|
||||
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
@@ -877,6 +878,16 @@ bool llm_arch_is_diffusion(const llm_arch & arch) {
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_arch_supports_rs_rollback(const llm_arch & arch) {
|
||||
switch (arch) {
|
||||
case LLM_ARCH_QWEN35:
|
||||
case LLM_ARCH_QWEN35MOE:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_arch_supports_sm_tensor(const llm_arch & arch) {
|
||||
switch (arch) {
|
||||
case LLM_ARCH_GROK:
|
||||
|
||||
@@ -637,3 +637,4 @@ bool llm_arch_is_recurrent (const llm_arch & arch);
|
||||
bool llm_arch_is_hybrid (const llm_arch & arch);
|
||||
bool llm_arch_is_diffusion (const llm_arch & arch);
|
||||
bool llm_arch_supports_sm_tensor(const llm_arch & arch);
|
||||
bool llm_arch_supports_rs_rollback(const llm_arch & arch);
|
||||
|
||||
+142
-15
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "ggml.h"
|
||||
#include "llama-arch.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-impl.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-io.h"
|
||||
@@ -21,6 +22,14 @@
|
||||
// llama_context
|
||||
//
|
||||
|
||||
// Maps the public context type onto the graph type used when building graphs.
// Throws std::runtime_error when ctx_type is not one of the known enumerators.
static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) {
    if (ctx_type == LLAMA_CONTEXT_TYPE_DEFAULT) {
        return LLM_GRAPH_TYPE_DEFAULT;
    }

    if (ctx_type == LLAMA_CONTEXT_TYPE_MTP) {
        return LLM_GRAPH_TYPE_DECODER_MTP;
    }

    throw std::runtime_error("Unsupported ctx type");
}
|
||||
|
||||
llama_context::llama_context(
|
||||
const llama_model & model,
|
||||
llama_context_params params) :
|
||||
@@ -42,6 +51,13 @@ llama_context::llama_context(
|
||||
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
|
||||
}
|
||||
|
||||
cparams.n_rs_seq = params.n_rs_seq;
|
||||
if (cparams.n_rs_seq > 0 && !llm_arch_supports_rs_rollback(model.arch)) {
|
||||
LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n",
|
||||
__func__, cparams.n_rs_seq);
|
||||
cparams.n_rs_seq = 0;
|
||||
}
|
||||
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch;
|
||||
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
|
||||
@@ -49,6 +65,7 @@ llama_context::llama_context(
|
||||
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
|
||||
cparams.embeddings = params.embeddings;
|
||||
cparams.embeddings_pre_norm = false;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.no_perf = params.no_perf;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
@@ -65,6 +82,8 @@ llama_context::llama_context(
|
||||
cparams.cb_eval = params.cb_eval;
|
||||
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
||||
|
||||
cparams.ctx_type = params.ctx_type;
|
||||
|
||||
// Initialize backend samplers here so they are part of the sampling graph
|
||||
// before the reserve passes run later in this function. This avoids a later
|
||||
// re-reserve when graph nodes change.
|
||||
@@ -206,6 +225,7 @@ llama_context::llama_context(
|
||||
LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
|
||||
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
||||
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
||||
LLAMA_LOG_INFO("%s: n_rs_seq = %u\n", __func__, cparams.n_rs_seq);
|
||||
|
||||
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
|
||||
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
|
||||
@@ -278,6 +298,7 @@ llama_context::llama_context(
|
||||
/*.type_k =*/ params.type_k,
|
||||
/*.type_v =*/ params.type_v,
|
||||
/*.swa_full =*/ params.swa_full,
|
||||
/*.ctx_type= */ cparams.ctx_type,
|
||||
};
|
||||
|
||||
memory.reset(model.create_memory(params_mem, cparams));
|
||||
@@ -860,6 +881,33 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
|
||||
return it->second.data();
|
||||
}
|
||||
|
||||
float * llama_context::get_embeddings_pre_norm() {
|
||||
output_reorder();
|
||||
|
||||
return embd_pre_norm.data;
|
||||
}
|
||||
|
||||
float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
|
||||
output_reorder();
|
||||
|
||||
try {
|
||||
if (embd_pre_norm.data == nullptr) {
|
||||
throw std::runtime_error("no pre-norm embeddings");
|
||||
}
|
||||
|
||||
const int64_t j = output_resolve_row(i);
|
||||
const uint32_t n_embd = model.hparams.n_embd;
|
||||
return embd_pre_norm.data + j*n_embd;
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what());
|
||||
#ifndef NDEBUG
|
||||
GGML_ABORT("fatal error");
|
||||
#else
|
||||
return nullptr;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
llama_token llama_context::get_sampled_token_ith(int32_t idx) {
|
||||
output_reorder();
|
||||
|
||||
@@ -1040,6 +1088,12 @@ void llama_context::set_embeddings(bool value) {
|
||||
//sched_need_reserve = true;
|
||||
}
|
||||
|
||||
// Turns extraction of the pre-norm hidden state on or off for this context.
// The requested value is logged at debug level before it is stored.
void llama_context::set_embeddings_pre_norm(bool value) {
    const int value_as_int = value; // same bool -> int promotion the variadic log call performs
    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value_as_int);

    auto & enabled = cparams.embeddings_pre_norm;
    enabled = value;
}
|
||||
|
||||
void llama_context::set_causal_attn(bool value) {
|
||||
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
||||
|
||||
@@ -1241,7 +1295,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
|
||||
}
|
||||
|
||||
int llama_context::encode(const llama_batch & batch_inp) {
|
||||
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
|
||||
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
|
||||
// so accept either present rather than requiring exactly one.
|
||||
GGML_ASSERT(batch_inp.token || batch_inp.embd);
|
||||
|
||||
if (batch_inp.n_tokens == 0) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
||||
@@ -1312,8 +1368,9 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
}
|
||||
}
|
||||
|
||||
auto * t_logits = res->get_logits();
|
||||
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
||||
auto * t_logits = res->get_logits();
|
||||
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
||||
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
|
||||
|
||||
// extract logits
|
||||
if (logits.data && t_logits) {
|
||||
@@ -1379,6 +1436,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
}
|
||||
}
|
||||
|
||||
// extract pre-norm embeddings (hidden state before the final output norm)
|
||||
if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
|
||||
GGML_ASSERT(backend_h != nullptr);
|
||||
|
||||
const uint32_t n_embd = hparams.n_embd;
|
||||
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size);
|
||||
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float));
|
||||
}
|
||||
|
||||
// TODO: hacky solution
|
||||
if (model.arch == LLM_ARCH_T5 && t_embd) {
|
||||
//cross.t_embd = t_embd;
|
||||
@@ -1531,7 +1598,9 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s
|
||||
}
|
||||
|
||||
int llama_context::decode(const llama_batch & batch_inp) {
|
||||
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
|
||||
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
|
||||
// so accept either present rather than requiring exactly one.
|
||||
GGML_ASSERT(batch_inp.token || batch_inp.embd);
|
||||
|
||||
if (!memory) {
|
||||
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
|
||||
@@ -1689,7 +1758,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
}
|
||||
|
||||
ggml_status status;
|
||||
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
|
||||
|
||||
const auto * res = process_ubatch(ubatch, ctx_type_to_graph_type(cparams.ctx_type), mctx.get(), status);
|
||||
|
||||
if (!res) {
|
||||
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
|
||||
@@ -1727,8 +1797,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
//}
|
||||
|
||||
auto * t_logits = res->get_logits();
|
||||
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
||||
auto * t_logits = res->get_logits();
|
||||
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
||||
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
|
||||
|
||||
if (t_embd && res->get_embd_pooled()) {
|
||||
t_embd = res->get_embd_pooled();
|
||||
@@ -1809,6 +1880,20 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
}
|
||||
}
|
||||
|
||||
// extract pre-norm embeddings (hidden state before the final output norm)
|
||||
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
|
||||
if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
|
||||
GGML_ASSERT(backend_h != nullptr);
|
||||
|
||||
const uint32_t n_embd = hparams.n_embd;
|
||||
float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd;
|
||||
|
||||
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size);
|
||||
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float));
|
||||
}
|
||||
|
||||
// Copy backend sampling output if this ubatch produced any sampling tensors.
|
||||
if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
|
||||
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
|
||||
@@ -1893,10 +1978,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
|
||||
const auto n_batch = cparams.n_batch;
|
||||
const auto n_vocab = vocab.n_tokens();
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_embd_out = hparams.n_embd_out();
|
||||
|
||||
bool has_logits = true;
|
||||
bool has_embd = cparams.embeddings;
|
||||
bool has_logits = true;
|
||||
bool has_embd = cparams.embeddings;
|
||||
bool has_embd_pre_norm = cparams.embeddings_pre_norm;
|
||||
|
||||
// TODO: hacky enc-dec support
|
||||
if (model.arch == LLM_ARCH_T5) {
|
||||
@@ -1908,8 +1995,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
size_t backend_float_count = 0;
|
||||
size_t backend_token_count = 0;
|
||||
|
||||
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
|
||||
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
|
||||
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
|
||||
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
|
||||
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;
|
||||
|
||||
// Allocate backend sampling output buffers if there are backend samplers configured.
|
||||
const bool has_sampling = !sampling.samplers.empty();
|
||||
@@ -1925,8 +2013,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
|
||||
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
|
||||
const size_t new_size =
|
||||
(logits.size + embd.size + backend_float_count) * sizeof(float) +
|
||||
( backend_token_count) * sizeof(llama_token);
|
||||
(logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) +
|
||||
( backend_token_count) * sizeof(llama_token);
|
||||
|
||||
// alloc only when more than the current capacity is required
|
||||
// TODO: also consider shrinking the buffer
|
||||
@@ -1942,6 +2030,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
buf_output = nullptr;
|
||||
logits.data = nullptr;
|
||||
embd.data = nullptr;
|
||||
embd_pre_norm.data = nullptr;
|
||||
}
|
||||
|
||||
auto * buft = ggml_backend_cpu_buffer_type();
|
||||
@@ -1970,6 +2059,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
|
||||
offset += embd.size * sizeof(float);
|
||||
|
||||
embd_pre_norm = has_embd_pre_norm ? buffer_view<float>{(float *) (base + offset), embd_pre_norm.size} : buffer_view<float>{nullptr, 0};
|
||||
offset += embd_pre_norm.size * sizeof(float);
|
||||
|
||||
if (has_sampling) {
|
||||
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
|
||||
offset += sampling.logits.size * sizeof(float);
|
||||
@@ -2034,6 +2126,12 @@ void llama_context::output_reorder() {
|
||||
}
|
||||
}
|
||||
|
||||
if (embd_pre_norm.size > 0) {
|
||||
for (uint64_t k = 0; k < n_embd; k++) {
|
||||
std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]);
|
||||
}
|
||||
}
|
||||
|
||||
if (!sampling.samplers.empty()) {
|
||||
assert(sampling.logits.size > 0);
|
||||
assert(sampling.probs.size > 0);
|
||||
@@ -2121,7 +2219,7 @@ ggml_cgraph * llama_context::graph_reserve(
|
||||
|
||||
auto * res = gf_res_reserve.get();
|
||||
|
||||
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
|
||||
const auto gparams = graph_params(res, ubatch, mctx, ctx_type_to_graph_type(cparams.ctx_type));
|
||||
|
||||
res->reset();
|
||||
|
||||
@@ -3100,7 +3198,7 @@ void llama_context::opt_epoch_iter(
|
||||
|
||||
auto * res = gf_res_prev.get();
|
||||
|
||||
const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
|
||||
const auto gparams = graph_params(res, ubatch, mctx.get(), ctx_type_to_graph_type(cparams.ctx_type));
|
||||
|
||||
res->reset();
|
||||
|
||||
@@ -3201,8 +3299,10 @@ llama_context_params llama_context_default_params() {
|
||||
/*.n_batch =*/ 2048,
|
||||
/*.n_ubatch =*/ 512,
|
||||
/*.n_seq_max =*/ 1,
|
||||
/*.n_rs_seq =*/ 0,
|
||||
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
|
||||
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
|
||||
/*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT,
|
||||
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
|
||||
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
|
||||
/*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
|
||||
@@ -3306,6 +3406,13 @@ llama_context * llama_init_from_model(
|
||||
model->hparams.pooling_type, params.pooling_type);
|
||||
}
|
||||
|
||||
if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP &&
|
||||
model->hparams.nextn_predict_layers == 0) {
|
||||
LLAMA_LOG_WARN("%s: context type MTP requested but model doesn't contain MTP layers\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
try {
|
||||
auto * ctx = new llama_context(*model, params);
|
||||
return ctx;
|
||||
@@ -3347,6 +3454,10 @@ uint32_t llama_n_seq_max(const llama_context * ctx) {
|
||||
return ctx->n_seq_max();
|
||||
}
|
||||
|
||||
uint32_t llama_n_rs_seq(const llama_context * ctx) {
|
||||
return ctx->get_cparams().n_rs_seq;
|
||||
}
|
||||
|
||||
// Returns a non-owning pointer to the model this context was created from.
const llama_model * llama_get_model(const llama_context * ctx) {
    const auto & mdl = ctx->get_model();

    return &mdl;
}
|
||||
@@ -3436,6 +3547,22 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
|
||||
return ctx->get_embeddings_seq(seq_id);
|
||||
}
|
||||
|
||||
// C API wrapper: forwards the flag to llama_context::set_embeddings_pre_norm.
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value) {
    llama_context & self = *ctx;

    self.set_embeddings_pre_norm(value);
}
|
||||
|
||||
// C API wrapper: returns the pre-norm embeddings buffer of the context.
// synchronize() is called first; the buffer is filled with an async backend
// copy during encode/decode, so presumably this waits for that copy to finish.
float * llama_get_embeddings_pre_norm(llama_context * ctx) {
    ctx->synchronize();

    float * const rows = ctx->get_embeddings_pre_norm();
    return rows;
}
|
||||
|
||||
// C API wrapper: returns the pre-norm embeddings row for output index i.
// The context is synchronized before the lookup; the result is whatever
// llama_context::get_embeddings_pre_norm_ith returns for i.
float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) {
    ctx->synchronize();

    float * const row = ctx->get_embeddings_pre_norm_ith(i);
    return row;
}
|
||||
|
||||
// C API wrapper: forwards the sampler for sequence seq_id to the context and
// returns the context's result.
bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
    const bool ok = ctx->set_sampler(seq_id, smpl);

    return ok;
}
|
||||
|
||||
@@ -84,6 +84,9 @@ struct llama_context {
|
||||
float * get_embeddings_ith(int32_t i);
|
||||
float * get_embeddings_seq(llama_seq_id seq_id);
|
||||
|
||||
float * get_embeddings_pre_norm();
|
||||
float * get_embeddings_pre_norm_ith(int32_t i);
|
||||
|
||||
llama_token * get_sampled_tokens() const;
|
||||
llama_token get_sampled_token_ith(int32_t idx);
|
||||
|
||||
@@ -107,6 +110,7 @@ struct llama_context {
|
||||
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
|
||||
|
||||
void set_embeddings (bool value);
|
||||
void set_embeddings_pre_norm(bool value);
|
||||
void set_causal_attn(bool value);
|
||||
void set_warmup(bool value);
|
||||
|
||||
@@ -278,6 +282,11 @@ private:
|
||||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||
buffer_view<float> embd = {nullptr, 0};
|
||||
|
||||
// hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd])
|
||||
// populated only when cparams.embeddings_pre_norm is enabled and the model graph
|
||||
// sets llm_graph_result::t_h_pre_norm
|
||||
buffer_view<float> embd_pre_norm = {nullptr, 0};
|
||||
|
||||
struct sampling_info {
|
||||
// !samplers.empty() to check if any samplers are active
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
@@ -12,6 +12,7 @@ struct llama_cparams {
|
||||
uint32_t n_batch;
|
||||
uint32_t n_ubatch;
|
||||
uint32_t n_seq_max;
|
||||
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback
|
||||
int32_t n_threads; // number of threads to use for generation
|
||||
int32_t n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
@@ -27,6 +28,7 @@ struct llama_cparams {
|
||||
float yarn_beta_slow;
|
||||
|
||||
bool embeddings;
|
||||
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
|
||||
bool causal_attn;
|
||||
bool offload_kqv;
|
||||
bool flash_attn;
|
||||
@@ -40,6 +42,7 @@ struct llama_cparams {
|
||||
bool kv_unified;
|
||||
bool pipeline_parallel;
|
||||
|
||||
enum llama_context_type ctx_type;
|
||||
enum llama_pooling_type pooling_type;
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
|
||||
@@ -88,3 +88,19 @@ LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model);
|
||||
LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i);
|
||||
|
||||
LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx);
|
||||
|
||||
//
|
||||
// pre-norm embeddings (hidden state before the final output norm)
|
||||
//
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
|
||||
LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value);
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx);
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||
LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
+2
-1
@@ -2528,7 +2528,8 @@ ggml_tensor * llm_graph_context::build_rs(
|
||||
int32_t rs_zero,
|
||||
const llm_graph_get_rows_fn & get_state_rows) const {
|
||||
|
||||
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
|
||||
GGML_UNUSED(rs_size);
|
||||
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, s->ne[1]);
|
||||
|
||||
// Clear a single state which will then be copied to the other cleared states.
|
||||
// Note that this is a no-op when the view is zero-sized.
|
||||
|
||||
@@ -32,6 +32,7 @@ enum llm_graph_type {
|
||||
LLM_GRAPH_TYPE_DEFAULT,
|
||||
LLM_GRAPH_TYPE_ENCODER,
|
||||
LLM_GRAPH_TYPE_DECODER,
|
||||
LLM_GRAPH_TYPE_DECODER_MTP,
|
||||
};
|
||||
|
||||
enum llm_ffn_op_type {
|
||||
@@ -644,6 +645,7 @@ public:
|
||||
ggml_tensor * get_logits() const { return t_logits; }
|
||||
ggml_tensor * get_embd() const { return t_embd; }
|
||||
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
|
||||
ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; }
|
||||
|
||||
ggml_cgraph * get_gf() const { return gf; }
|
||||
ggml_context * get_ctx() const { return ctx_compute.get(); }
|
||||
@@ -672,6 +674,7 @@ public:
|
||||
ggml_tensor * t_logits = nullptr;
|
||||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm
|
||||
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_candidates;
|
||||
|
||||
@@ -229,6 +229,12 @@ uint32_t llama_hparams::n_embd_head_v_mla() const {
|
||||
}
|
||||
|
||||
bool llama_hparams::has_kv(uint32_t il) const {
|
||||
if (kv_only_nextn) {
|
||||
// MTP head: only the trailing nextn_predict_layers blocks own a KV cache;
|
||||
// the leading trunk blocks are not executed in this graph.
|
||||
return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers);
|
||||
}
|
||||
|
||||
if (n_layer_kv_from_start >= 0) {
|
||||
if (il < (uint32_t) n_layer_kv_from_start) {
|
||||
return true;
|
||||
|
||||
@@ -92,6 +92,8 @@ struct llama_hparams {
|
||||
uint32_t moe_latent_size = 0;
|
||||
uint32_t nextn_predict_layers = 0;
|
||||
|
||||
bool kv_only_nextn = false; // if true, only the last nextn_predict_layers blocks have a KV cache (MTP head arches)
|
||||
|
||||
float f_norm_eps;
|
||||
float f_norm_rms_eps;
|
||||
float f_norm_group_eps;
|
||||
|
||||
@@ -24,6 +24,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_rs_seq,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
@@ -54,6 +55,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
|
||||
offload,
|
||||
rs_size,
|
||||
n_seq_max,
|
||||
n_rs_seq,
|
||||
filter_recr == nullptr ?
|
||||
[&](int32_t il) { return hparams.is_recurrent(il); }
|
||||
: filter_recr
|
||||
|
||||
@@ -34,6 +34,7 @@ public:
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_rs_seq,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
|
||||
@@ -24,6 +24,7 @@ llama_memory_hybrid::llama_memory_hybrid(
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_rs_seq,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
@@ -54,6 +55,7 @@ llama_memory_hybrid::llama_memory_hybrid(
|
||||
offload,
|
||||
rs_size,
|
||||
n_seq_max,
|
||||
n_rs_seq,
|
||||
filter_recr == nullptr ?
|
||||
[&](int32_t il) { return hparams.is_recurrent(il); }
|
||||
: filter_recr
|
||||
|
||||
@@ -34,6 +34,7 @@ public:
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_rs_seq,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
|
||||
+104
-15
@@ -24,6 +24,7 @@ llama_memory_recurrent::llama_memory_recurrent(
|
||||
bool offload,
|
||||
uint32_t mem_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_rs_seq,
|
||||
const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
|
||||
const int32_t n_layer = hparams.n_layer;
|
||||
|
||||
@@ -31,6 +32,9 @@ llama_memory_recurrent::llama_memory_recurrent(
|
||||
size = mem_size;
|
||||
used = 0;
|
||||
|
||||
this->n_rs_seq = n_rs_seq;
|
||||
rs_idx.assign(n_seq_max, 0);
|
||||
|
||||
cells.clear();
|
||||
cells.resize(mem_size);
|
||||
|
||||
@@ -92,8 +96,9 @@ llama_memory_recurrent::llama_memory_recurrent(
|
||||
throw std::runtime_error("failed to create ggml context for rs cache");
|
||||
}
|
||||
|
||||
ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), mem_size);
|
||||
ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), mem_size);
|
||||
const uint32_t n_rows = mem_size * (1 + n_rs_seq);
|
||||
ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), n_rows);
|
||||
ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows);
|
||||
ggml_format_name(r, "cache_r_l%d", i);
|
||||
ggml_format_name(s, "cache_s_l%d", i);
|
||||
r_l[i] = r;
|
||||
@@ -115,8 +120,8 @@ llama_memory_recurrent::llama_memory_recurrent(
|
||||
const size_t memory_size_r = size_r_bytes();
|
||||
const size_t memory_size_s = size_s_bytes();
|
||||
|
||||
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
|
||||
(float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max,
|
||||
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs %2u rs_seq), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
|
||||
(float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, n_rs_seq,
|
||||
ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
|
||||
}
|
||||
@@ -138,10 +143,11 @@ void llama_memory_recurrent::clear(bool data) {
|
||||
ggml_backend_buffer_clear(buf.get(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
std::fill(rs_idx.begin(), rs_idx.end(), 0);
|
||||
}
|
||||
|
||||
bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
//printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1);
|
||||
uint32_t new_head = size;
|
||||
|
||||
if (p0 < 0) {
|
||||
@@ -152,6 +158,15 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
||||
p1 = std::numeric_limits<llama_pos>::max();
|
||||
}
|
||||
|
||||
const bool rm_all = p0 == 0 && p1 == std::numeric_limits<llama_pos>::max();
|
||||
if (rm_all) {
|
||||
if (seq_id >= 0) {
|
||||
set_rs_idx(seq_id, 0);
|
||||
} else {
|
||||
std::fill(rs_idx.begin(), rs_idx.end(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
// models like Mamba or RWKV can't have a state partially erased at the end
|
||||
// of the sequence because their state isn't preserved for previous tokens
|
||||
if (seq_id >= (int64_t) size) {
|
||||
@@ -161,10 +176,16 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
||||
if (0 <= seq_id) {
|
||||
int32_t & tail_id = cells[seq_id].tail;
|
||||
if (tail_id >= 0) {
|
||||
const auto & cell = cells[tail_id];
|
||||
// partial intersection is invalid if it includes the final pos
|
||||
auto & cell = cells[tail_id];
|
||||
|
||||
// partial rollback via per-token snapshot index (bounded by n_rs_seq)
|
||||
if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
|
||||
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1);
|
||||
const llama_pos rollback = cell.pos - (p0 - 1);
|
||||
if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) {
|
||||
set_rs_idx(seq_id, (uint32_t) rollback);
|
||||
cell.pos = p0 - 1;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
// invalidate tails which will be cleared
|
||||
@@ -368,6 +389,13 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Store the rollback snapshot index for one sequence.
//   seq_id - sequence to update; ids outside [0, rs_idx.size()) are silently ignored
//   idx    - requested snapshot index; values above n_rs_seq are capped at n_rs_seq
void llama_memory_recurrent::set_rs_idx(llama_seq_id seq_id, uint32_t idx) {
    const bool seq_in_range = seq_id >= 0 && (size_t) seq_id < rs_idx.size();

    if (seq_in_range) {
        rs_idx[seq_id] = idx < n_rs_seq ? idx : n_rs_seq;
    }
}
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
|
||||
std::map<ggml_backend_buffer_type_t, size_t> ret;
|
||||
for (const auto & [_, buf] : ctxs_bufs) {
|
||||
@@ -703,6 +731,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
|
||||
GGML_UNUSED(flags);
|
||||
|
||||
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
||||
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges_data; // logical source row ranges
|
||||
uint32_t cell_count = 0;
|
||||
|
||||
// Count the number of cells with the specified seq_id
|
||||
@@ -712,6 +741,35 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
|
||||
const auto & cell = cells[i];
|
||||
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
|
||||
++cell_count;
|
||||
uint32_t rs_idx_cur = 0;
|
||||
|
||||
if (n_rs_seq != 0) {
|
||||
if (seq_id != -1) {
|
||||
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < rs_idx.size());
|
||||
rs_idx_cur = rs_idx[seq_id];
|
||||
} else {
|
||||
bool has_rs_idx = false;
|
||||
for (const llama_seq_id cell_seq_id : cell.seq_id) {
|
||||
GGML_ASSERT(cell_seq_id >= 0 && (size_t) cell_seq_id < rs_idx.size());
|
||||
|
||||
const uint32_t seq_rs_idx = rs_idx[cell_seq_id];
|
||||
if (!has_rs_idx) {
|
||||
rs_idx_cur = seq_rs_idx;
|
||||
has_rs_idx = true;
|
||||
} else if (rs_idx_cur != seq_rs_idx) {
|
||||
GGML_ABORT("cannot write shared recurrent state with different rollback indices");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t cell_id = rs_idx_cur * size + (cell.src >= 0 ? cell.src : (int32_t) i);
|
||||
if (cell_ranges_data.empty() || cell_ranges_data.back().second != cell_id) {
|
||||
cell_ranges_data.emplace_back(cell_id, cell_id + 1);
|
||||
} else {
|
||||
cell_ranges_data.back().second++;
|
||||
}
|
||||
|
||||
if (cell_range_begin == size) {
|
||||
cell_range_begin = i;
|
||||
}
|
||||
@@ -726,7 +784,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
|
||||
cell_ranges.emplace_back(cell_range_begin, size);
|
||||
}
|
||||
|
||||
if (flags % LLAMA_STATE_SEQ_FLAGS_ON_DEVICE && cell_ranges.size() > 1) {
|
||||
if ((flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) && cell_ranges.size() > 1) {
|
||||
GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n");
|
||||
}
|
||||
|
||||
@@ -737,10 +795,16 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
|
||||
}
|
||||
GGML_ASSERT(cell_count == cell_count_check);
|
||||
|
||||
cell_count_check = 0;
|
||||
for (const auto & range : cell_ranges_data) {
|
||||
cell_count_check += range.second - range.first;
|
||||
}
|
||||
GGML_ASSERT(cell_count == cell_count_check);
|
||||
|
||||
io.write(&cell_count, sizeof(cell_count));
|
||||
|
||||
state_write_meta(io, cell_ranges, seq_id);
|
||||
state_write_data(io, cell_ranges);
|
||||
state_write_data(io, cell_ranges_data);
|
||||
}
|
||||
|
||||
void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
@@ -762,6 +826,14 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i
|
||||
}
|
||||
throw std::runtime_error("failed to restore kv cache");
|
||||
}
|
||||
|
||||
if (n_rs_seq != 0) {
|
||||
if (seq_id == -1) {
|
||||
std::fill(rs_idx.begin(), rs_idx.end(), 0);
|
||||
} else {
|
||||
set_rs_idx(seq_id, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
||||
@@ -804,7 +876,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
|
||||
const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
||||
io.write(&r_size_row, sizeof(r_size_row));
|
||||
|
||||
// Write each range of cells of r_size_row length
|
||||
// Write each logical cell row range. With pending recurrent rollback,
|
||||
// the logical current state may live in a rollback snapshot plane.
|
||||
for (const auto & range : cell_ranges) {
|
||||
const size_t range_size = range.second - range.first;
|
||||
const size_t buf_size = range_size * r_size_row;
|
||||
@@ -825,7 +898,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
|
||||
const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
||||
io.write(&s_size_row, sizeof(s_size_row));
|
||||
|
||||
// Write each range of S tensor rows
|
||||
// Write each logical cell row range. With pending recurrent rollback,
|
||||
// the logical current state may live in a rollback snapshot plane.
|
||||
for (const auto & range : cell_ranges) {
|
||||
const size_t range_size = range.second - range.first;
|
||||
const size_t buf_size = range_size * s_size_row;
|
||||
@@ -852,9 +926,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
|
||||
// Write GQA embedding size
|
||||
io.write(&n_embd_s, sizeof(n_embd_s));
|
||||
|
||||
// For each row, we get the element values of each cell
|
||||
// For each row, we get the element values of each logical cell
|
||||
for (uint32_t j = 0; j < n_embd_s; ++j) {
|
||||
// Write each range of cells of s_size_el length
|
||||
for (const auto & range : cell_ranges) {
|
||||
const size_t range_size = range.second - range.first;
|
||||
const size_t src_offset = (range.first + j * mem_size) * s_size_el;
|
||||
@@ -1163,5 +1236,21 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
|
||||
}
|
||||
|
||||
int32_t llama_memory_recurrent_context::s_copy(int i) const {
|
||||
return mem->cells[i + mem->head].src0;
|
||||
const uint32_t cell_idx = i + mem->head;
|
||||
const int32_t src0 = mem->cells[cell_idx].src0;
|
||||
|
||||
if (mem->n_rs_seq == 0) {
|
||||
return src0;
|
||||
}
|
||||
|
||||
uint32_t idx = 0;
|
||||
if (!mem->cells[cell_idx].seq_id.empty()) {
|
||||
const llama_seq_id seq = *mem->cells[cell_idx].seq_id.begin();
|
||||
if (seq >= 0 && (size_t) seq < mem->rs_idx.size()) {
|
||||
idx = mem->rs_idx[seq];
|
||||
// reset rollback idx
|
||||
mem->rs_idx[seq] = 0;
|
||||
}
|
||||
}
|
||||
return (int32_t)(idx * mem->size) + src0;
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ public:
|
||||
bool offload,
|
||||
uint32_t mem_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_rs_seq,
|
||||
const layer_filter_cb & filter);
|
||||
|
||||
~llama_memory_recurrent() = default;
|
||||
@@ -69,6 +70,13 @@ public:
|
||||
uint32_t size = 0; // total number of cells, shared across all sequences
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
||||
// number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups
|
||||
uint32_t n_rs_seq = 0;
|
||||
// per-seq rollback index
|
||||
std::vector<uint32_t> rs_idx;
|
||||
|
||||
void set_rs_idx(llama_seq_id seq_id, uint32_t idx);
|
||||
|
||||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-graph.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
@@ -20,6 +21,8 @@ struct llama_memory_params {
|
||||
|
||||
// use full-size SWA cache
|
||||
bool swa_full;
|
||||
|
||||
llama_context_type ctx_type;
|
||||
};
|
||||
|
||||
enum llama_memory_status {
|
||||
|
||||
@@ -1312,9 +1312,16 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte
|
||||
return tensor;
|
||||
}
|
||||
|
||||
void llama_model_loader::done_getting_tensors() const {
|
||||
if (n_created != n_tensors) {
|
||||
throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
|
||||
void llama_model_loader::done_getting_tensors(bool partial) const {
|
||||
if (n_created > n_tensors) {
|
||||
throw std::runtime_error(format("%s: too many tensors created; expected %d, got %d", __func__, n_tensors, n_created));
|
||||
}
|
||||
if (n_created < n_tensors) {
|
||||
if (!partial) {
|
||||
throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
|
||||
}
|
||||
LLAMA_LOG_INFO("%s: partial load — used %d of %d tensors in the file (rest belong to a sibling model on the same .gguf)\n",
|
||||
__func__, n_created, n_tensors);
|
||||
}
|
||||
if (n_tensors_moved > 0) {
|
||||
LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n",
|
||||
|
||||
@@ -184,7 +184,7 @@ struct llama_model_loader {
|
||||
|
||||
struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required = true);
|
||||
|
||||
void done_getting_tensors() const;
|
||||
void done_getting_tensors(bool partial = false) const;
|
||||
|
||||
void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr);
|
||||
|
||||
|
||||
+27
-3
@@ -1947,6 +1947,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
// checks
|
||||
default:
|
||||
{
|
||||
// The MTP head is dense-attention only on hybrid Qwen3.5/3.6, so use a plain
|
||||
// attention KV cache for the MTP context instead of the hybrid wrapper.
|
||||
const bool mtp_on_hybrid_qwen35 =
|
||||
params.ctx_type == LLAMA_CONTEXT_TYPE_MTP &&
|
||||
(arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE);
|
||||
|
||||
if (llm_arch_is_recurrent(arch)) {
|
||||
res = new llama_memory_recurrent(
|
||||
*this,
|
||||
@@ -1955,8 +1961,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
cparams.offload_kqv,
|
||||
std::max((uint32_t) 1, cparams.n_seq_max),
|
||||
cparams.n_seq_max,
|
||||
cparams.n_rs_seq,
|
||||
nullptr);
|
||||
} else if (llm_arch_is_hybrid(arch)) {
|
||||
} else if (llm_arch_is_hybrid(arch) && !mtp_on_hybrid_qwen35) {
|
||||
// The main difference between hybrid architectures is the
|
||||
// layer filters, so pick the right one here
|
||||
llama_memory_hybrid::layer_filter_cb filter_attn = nullptr;
|
||||
@@ -1971,6 +1978,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
filter_recr = [&](int32_t il) {
|
||||
return hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
|
||||
};
|
||||
} else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) {
|
||||
const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
filter_attn = [&, n_main](int32_t il) {
|
||||
return (uint32_t)il < n_main && !hparams.is_recurrent(il);
|
||||
};
|
||||
filter_recr = [&, n_main](int32_t il) {
|
||||
return (uint32_t)il < n_main && hparams.is_recurrent(il);
|
||||
};
|
||||
}
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
@@ -1988,6 +2003,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
/* recurrent_type_s */ GGML_TYPE_F32,
|
||||
/* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max),
|
||||
/* n_seq_max */ cparams.n_seq_max,
|
||||
/* n_rs_seq */ cparams.n_rs_seq,
|
||||
/* offload */ cparams.offload_kqv,
|
||||
/* unified */ cparams.kv_unified,
|
||||
/* filter_attn */ std::move(filter_attn),
|
||||
@@ -2006,6 +2022,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
/* recurrent_type_v */ GGML_TYPE_F32,
|
||||
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
|
||||
/* n_seq_max */ cparams.n_seq_max,
|
||||
/* n_rs_seq */ cparams.n_rs_seq,
|
||||
/* offload */ cparams.offload_kqv,
|
||||
/* unified */ cparams.kv_unified,
|
||||
/* filter_attn */ std::move(filter_attn),
|
||||
@@ -2013,6 +2030,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
}
|
||||
} else {
|
||||
llama_memory_i::layer_reuse_cb reuse = nullptr;
|
||||
llama_kv_cache::layer_filter_cb filter = nullptr;
|
||||
|
||||
if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) {
|
||||
reuse = [&](int32_t il) {
|
||||
@@ -2024,6 +2042,11 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
};
|
||||
}
|
||||
|
||||
if (mtp_on_hybrid_qwen35) {
|
||||
const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
filter = [n_main](int32_t il) { return (uint32_t)il >= n_main; };
|
||||
}
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
GGML_ASSERT(hparams.is_swa_any());
|
||||
|
||||
@@ -2039,7 +2062,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_ubatch,
|
||||
1,
|
||||
nullptr,
|
||||
filter,
|
||||
reuse);
|
||||
} else {
|
||||
GGML_ASSERT(!hparams.is_swa_any());
|
||||
@@ -2056,7 +2079,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
1,
|
||||
hparams.n_swa,
|
||||
hparams.swa_type,
|
||||
nullptr,
|
||||
filter,
|
||||
nullptr);
|
||||
}
|
||||
}
|
||||
@@ -2159,6 +2182,7 @@ int32_t llama_model_n_swa(const llama_model * model) {
|
||||
return model->hparams.n_swa;
|
||||
}
|
||||
|
||||
|
||||
uint32_t llama_model_n_cls_out(const struct llama_model * model) {
|
||||
return model->hparams.n_cls_out;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "models.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
// utility to get one slice from the third dimension
|
||||
// input dim: [x, y, c, b]
|
||||
@@ -397,7 +398,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
|
||||
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
|
||||
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
|
||||
|
||||
ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
|
||||
// K=1 (final state only): reshape to 3D (S_v*S_v*H_v, 1, n_seqs) for ggml_gated_delta_net.
|
||||
ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, S_v * S_v * H_v, 1, n_seqs);
|
||||
ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d);
|
||||
if (n_tokens == 1) {
|
||||
cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il);
|
||||
} else {
|
||||
@@ -443,3 +446,141 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
|
||||
|
||||
return build_delta_net_chunking(q, k, v, g, b, s, il);
|
||||
}
|
||||
|
||||
bool llm_build_delta_net_base::keep_rs() const {
|
||||
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
return cparams.n_rs_seq > 0
|
||||
&& n_seq_tokens > 1
|
||||
&& (uint32_t) n_seq_tokens <= 1 + cparams.n_rs_seq;
|
||||
}
|
||||
|
||||
// Builds the convolution input for one delta-net layer and schedules the write-back
// of the updated conv state into conv_states_all.
//
//   inp              - recurrent-state graph input; its mctx supplies the head offset (get_head)
//                      and the cache size (get_size)
//   conv_states_all  - cache tensor that the conv state is read from and written back to
//   qkv_mixed        - activations that are transposed and appended to the cached state along dim 0
//   conv_kernel_size - the cached state is viewed with conv_kernel_size - 1 entries along dim 0
//   conv_channels    - size of dim 1 of the cached state view
//   il               - layer index, passed to cb() together with each tensor label
//
// Returns conv_input: the cached state concatenated with the transposed qkv_mixed along dim 0.
// The state write-back is added to the graph `gf` as a side effect.
ggml_tensor * llm_build_delta_net_base::build_conv_state(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_tensor * conv_states_all,
|
||||
ggml_tensor * qkv_mixed,
|
||||
int64_t conv_kernel_size,
|
||||
int64_t conv_channels,
|
||||
int il) {
|
||||
const auto * mctx_cur = inp->mctx;
|
||||
const auto kv_head = mctx_cur->get_head();
|
||||
const uint32_t mem_size = mctx_cur->get_size();
|
||||
const int64_t n_seqs = ubatch.n_seqs;
|
||||
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
// true -> write one snapshot per token (rollback path), false -> write only the final state
const bool keep = keep_rs();
|
||||
|
||||
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
|
||||
cb(conv_states, "conv_states", il);
|
||||
|
||||
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
|
||||
cb(conv_states, "conv_states_reshaped", il);
|
||||
|
||||
qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
|
||||
cb(qkv_mixed, "qkv_mixed_transposed", il);
|
||||
|
||||
// prepend the cached state to the new tokens along dim 0
ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
|
||||
cb(conv_input, "conv_input", il);
|
||||
|
||||
if (!keep) {
|
||||
// single-slot path: the trailing (conv_kernel_size - 1) entries of dim 0 become the new state;
// the view offset skips the first (conv_input->ne[0] - conv_states->ne[0]) elements
ggml_tensor * last_conv_states =
|
||||
ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
|
||||
conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
|
||||
cb(last_conv_states, "last_conv_states", il);
|
||||
|
||||
// destination: n_seqs rows of conv_states_all starting at row kv_head
ggml_tensor * state_update_target =
|
||||
ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
|
||||
kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
|
||||
cb(state_update_target, "state_update_target", il);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
|
||||
} else {
|
||||
// snapshot path: one state per token, each written to its own plane of mem_size rows
const int64_t row_count = (conv_kernel_size - 1) * conv_channels;
|
||||
const size_t row_size = row_count * ggml_element_size(conv_states_all);
|
||||
for (int64_t t = 1; t <= n_seq_tokens; ++t) {
|
||||
// the window that starts t elements into dim 0 goes to plane (n_seq_tokens - t),
// so the window after the last token lands in plane 0
const uint32_t slot = (uint32_t)(n_seq_tokens - t);
|
||||
ggml_tensor * src =
|
||||
ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs,
|
||||
conv_input->nb[1], conv_input->nb[2],
|
||||
t * ggml_element_size(conv_input));
|
||||
ggml_tensor * dst =
|
||||
ggml_view_2d(ctx0, conv_states_all, row_count, n_seqs,
|
||||
conv_states_all->nb[1],
|
||||
((size_t) slot * mem_size + kv_head) * row_size);
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
|
||||
}
|
||||
}
|
||||
|
||||
return conv_input;
|
||||
}
|
||||
|
||||
// Runs delta-net attention for one layer and schedules the write-back of the new
// recurrent state(s) into ssm_states_all.
//
//   inp            - recurrent-state graph input; its mctx supplies the head offset and cache size
//   ssm_states_all - cache tensor that receives the updated state rows
//   q, k, v, g, b  - attention inputs, forwarded unchanged to build_delta_net / ggml_gated_delta_net
//   s              - current state; ne[0] is read as S_v, ne[2] as H_v and ne[3] as n_seqs
//   il             - layer index, passed to cb() together with each tensor label
//
// Returns the attention output tensor. State copies are added to the graph `gf` as a side effect.
ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_tensor * ssm_states_all,
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * b,
|
||||
ggml_tensor * s,
|
||||
int il) {
|
||||
const auto * mctx_cur = inp->mctx;
|
||||
const auto kv_head = mctx_cur->get_head();
|
||||
const uint32_t mem_size = mctx_cur->get_size();
|
||||
|
||||
const int64_t S_v = s->ne[0];
|
||||
const int64_t H_v = s->ne[2];
|
||||
const int64_t n_seqs = s->ne[3];
|
||||
const int64_t n_seq_tokens = q->ne[2];
|
||||
|
||||
// no snapshots requested: regular delta-net, only the final state is written
// to n_seqs rows of ssm_states_all starting at row kv_head
if (!keep_rs()) {
|
||||
auto attn_out = build_delta_net(q, k, v, g, b, s, il);
|
||||
ggml_tensor * output = attn_out.first;
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
cb(output, "attn_output", il);
|
||||
cb(new_state, "new_state", il);
|
||||
|
||||
ggml_build_forward_expand(gf,
|
||||
ggml_cpy(ctx0, new_state,
|
||||
ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
|
||||
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
// snapshot path: D elements per state, K = n_rs_seq + 1 state slots per sequence
const int64_t D = S_v * S_v * H_v;
|
||||
const int64_t K = (int64_t) cparams.n_rs_seq + 1;
|
||||
|
||||
// TODO: remove pad + simplify
|
||||
// the state is reshaped to (D, 1, n_seqs) and dim 1 is padded by K - 1 before the fused op
ggml_tensor * state_in_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs);
|
||||
ggml_tensor * state_3d = ggml_pad(ctx0, state_in_3d, 0, K - 1, 0, 0);
|
||||
|
||||
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, state_3d);
|
||||
cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il);
|
||||
|
||||
// NOTE(review): the views below treat gdn_out as [attention output | K state snapshots],
// with the snapshots starting attn_score_elems elements in - confirm against the op's layout
const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs;
|
||||
const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs;
|
||||
|
||||
// attention output: (S_v, H_v, n_seq_tokens, n_seqs) view at offset 0
ggml_tensor * output = ggml_view_4d(ctx0, gdn_out,
|
||||
S_v, H_v, n_seq_tokens, n_seqs,
|
||||
ggml_row_size(gdn_out->type, S_v),
|
||||
ggml_row_size(gdn_out->type, S_v * H_v),
|
||||
ggml_row_size(gdn_out->type, S_v * H_v * n_seq_tokens),
|
||||
0);
|
||||
cb(output, "attn_output", il);
|
||||
|
||||
const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all);
|
||||
for (int64_t k_i = 0; k_i < K; ++k_i) {
|
||||
// snapshot k_i is stored in plane (K - 1 - k_i); each plane spans mem_size rows of the cache
const uint32_t cache_slot = (uint32_t) (K - 1 - k_i);
|
||||
ggml_tensor * src = ggml_view_4d(ctx0, gdn_out,
|
||||
S_v, S_v, H_v, n_seqs,
|
||||
ggml_row_size(gdn_out->type, S_v),
|
||||
ggml_row_size(gdn_out->type, S_v * S_v),
|
||||
ggml_row_size(gdn_out->type, S_v * S_v * H_v),
|
||||
ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap));
|
||||
ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all,
|
||||
hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
|
||||
((size_t) cache_slot * mem_size + kv_head) * row_size);
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
+35
-1
@@ -46,7 +46,7 @@ struct llm_build_delta_net_base : public llm_graph_context {
|
||||
ggml_tensor * s,
|
||||
int il);
|
||||
|
||||
// use the ggml_gated_delta_net fused operator
|
||||
// use the ggml_gated_delta_net fused operator (K=1; state has shape (D, 1, n_seqs))
|
||||
std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_fused(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
@@ -65,6 +65,32 @@ struct llm_build_delta_net_base : public llm_graph_context {
|
||||
ggml_tensor * b,
|
||||
ggml_tensor * s,
|
||||
int il);
|
||||
|
||||
// true when speculative rollback is enabled and the batch fits in the rs cache
|
||||
bool keep_rs() const;
|
||||
|
||||
// read conv state from cache, concat with qkv_mixed, write back (single slot or per-token)
|
||||
// qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs)
|
||||
ggml_tensor * build_conv_state(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_tensor * conv_states_all,
|
||||
ggml_tensor * qkv_mixed,
|
||||
int64_t conv_kernel_size,
|
||||
int64_t conv_channels,
|
||||
int il);
|
||||
|
||||
// run delta-net attention and write the new recurrent state(s) back to ssm_states_all
|
||||
// s: (head_v_dim, head_v_dim, num_v_heads, n_seqs); returns output: (head_v_dim, num_v_heads, n_seq_tokens, n_seqs)
|
||||
ggml_tensor * build_recurrent_attn(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_tensor * ssm_states_all,
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * b,
|
||||
ggml_tensor * s,
|
||||
int il);
|
||||
};
|
||||
|
||||
struct llm_build_rwkv6_base : public llm_graph_context {
|
||||
@@ -1739,6 +1765,10 @@ struct llama_model_qwen35 : public llama_model_base {
|
||||
const llama_model & model;
|
||||
};
|
||||
|
||||
// Nested graph context type: inherits llm_graph_context and declares only a
// constructor taking the model and the graph params (defined outside this chunk).
// NOTE(review): presumably the graph built for the MTP head, chosen by build_arch_graph - confirm.
struct graph_mtp : public llm_graph_context {
|
||||
graph_mtp(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
|
||||
};
|
||||
|
||||
@@ -1781,6 +1811,10 @@ struct llama_model_qwen35moe : public llama_model_base {
|
||||
const llama_model & model;
|
||||
};
|
||||
|
||||
struct graph_mtp : public llm_graph_context {
|
||||
graph_mtp(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
|
||||
};
|
||||
|
||||
|
||||
+232
-78
@@ -12,16 +12,22 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
|
||||
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
|
||||
|
||||
// Mark recurrent layers (linear attention layers)
|
||||
// NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
|
||||
|
||||
// Mark recurrent layers (linear attention layers). MTP layers are dense
|
||||
// attention-only and must be flagged non-recurrent.
|
||||
{
|
||||
const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
uint32_t full_attn_interval = 4;
|
||||
ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false);
|
||||
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
|
||||
hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0);
|
||||
hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0);
|
||||
}
|
||||
}
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
switch (hparams.n_layer - hparams.nextn_predict_layers) {
|
||||
case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break;
|
||||
case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break;
|
||||
case 64: type = LLM_TYPE_27B; break;
|
||||
@@ -29,9 +35,14 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) {
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_qwen35::load_arch_tensors(llama_model_loader &) {
|
||||
void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) {
|
||||
LLAMA_LOAD_LOCALS;
|
||||
|
||||
const uint32_t n_main = n_layer - hparams.nextn_predict_layers;
|
||||
const bool mtp_only = (hparams.nextn_predict_layers > 0) &&
|
||||
(ml.get_weight("blk.0.attn_norm.weight") == nullptr);
|
||||
const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
||||
|
||||
// output
|
||||
@@ -43,50 +54,85 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
// Calculate dimensions from hyperparameters
|
||||
const int64_t head_k_dim = hparams.ssm_d_state;
|
||||
const int64_t head_v_dim = hparams.ssm_d_state;
|
||||
const int64_t n_k_heads = hparams.ssm_n_group;
|
||||
const int64_t n_v_heads = hparams.ssm_dt_rank;
|
||||
const int64_t key_dim = head_k_dim * n_k_heads;
|
||||
const int64_t value_dim = head_v_dim * n_v_heads;
|
||||
const int64_t conv_dim = key_dim * 2 + value_dim;
|
||||
auto load_block_trunk = [&](int il, int flags) {
|
||||
auto & layer = layers[il];
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
// Calculate dimensions from hyperparameters
|
||||
const int64_t head_k_dim = hparams.ssm_d_state;
|
||||
const int64_t head_v_dim = hparams.ssm_d_state;
|
||||
const int64_t n_k_heads = hparams.ssm_n_group;
|
||||
const int64_t n_v_heads = hparams.ssm_dt_rank;
|
||||
const int64_t key_dim = head_k_dim * n_k_heads;
|
||||
const int64_t value_dim = head_v_dim * n_v_heads;
|
||||
const int64_t conv_dim = key_dim * 2 + value_dim;
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags);
|
||||
|
||||
if (!hparams.is_recurrent(i)) {
|
||||
if (!hparams.is_recurrent(il)) {
|
||||
// Attention layers
|
||||
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
|
||||
create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags);
|
||||
|
||||
// Q/K normalization for attention layers
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags);
|
||||
} else {
|
||||
// Linear attention (gated delta net) specific tensors
|
||||
// Create tensors with calculated dimensions
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
|
||||
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
|
||||
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
|
||||
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
|
||||
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0);
|
||||
layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0);
|
||||
layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0);
|
||||
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
|
||||
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
|
||||
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
|
||||
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags);
|
||||
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags);
|
||||
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags);
|
||||
layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags);
|
||||
layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags);
|
||||
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags);
|
||||
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags);
|
||||
}
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, flags);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, flags);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, flags);
|
||||
};
|
||||
|
||||
auto load_block_mtp = [&](int il) {
|
||||
auto & layer = layers[il];
|
||||
|
||||
// MTP block looks like a full-attention Qwen3.5 decoder block.
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0);
|
||||
|
||||
create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0);
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0);
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, 0);
|
||||
|
||||
// NextN-specific tensors that define the MTP block.
|
||||
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0);
|
||||
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0);
|
||||
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0);
|
||||
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED);
|
||||
};
|
||||
|
||||
for (int i = 0; i < (int) n_main; ++i) {
|
||||
load_block_trunk(i, trunk_flags);
|
||||
}
|
||||
for (int i = (int) n_main; i < n_layer; ++i) {
|
||||
load_block_mtp(i);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_qwen35::build_arch_graph(const llm_graph_params & params) const {
|
||||
if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) {
|
||||
return std::make_unique<graph_mtp>(*this, params);
|
||||
}
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
@@ -111,7 +157,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
|
||||
const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers;
|
||||
for (int il = 0; il < n_transformer_layers; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
|
||||
@@ -128,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
|
||||
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
@@ -160,6 +208,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
|
||||
}
|
||||
cur = inpL;
|
||||
|
||||
cb(cur, "h_pre_norm", -1);
|
||||
res->t_h_pre_norm = cur;
|
||||
|
||||
// Final norm
|
||||
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
|
||||
@@ -297,8 +348,6 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear(
|
||||
const int64_t head_v_dim = d_inner / num_v_heads;
|
||||
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
||||
const auto kv_head = mctx_cur->get_head();
|
||||
|
||||
GGML_ASSERT(n_seqs != 0);
|
||||
GGML_ASSERT(ubatch.equal_seqs());
|
||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||
@@ -328,41 +377,14 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear(
|
||||
|
||||
gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
|
||||
|
||||
// Get convolution states from cache
|
||||
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
||||
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
||||
|
||||
// Build the convolution states tensor
|
||||
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
|
||||
cb(conv_states, "conv_states", il);
|
||||
|
||||
// Calculate convolution kernel size
|
||||
ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
|
||||
const int64_t conv_kernel_size = conv_kernel->ne[0];
|
||||
const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
|
||||
|
||||
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
|
||||
cb(conv_states, "conv_states_reshaped", il);
|
||||
|
||||
qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
|
||||
cb(qkv_mixed, "qkv_mixed_transposed", il);
|
||||
|
||||
ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
|
||||
cb(conv_input, "conv_input", il);
|
||||
|
||||
// Update convolution state cache
|
||||
// Extract the last (conv_kernel_size - 1) states from conv_input
|
||||
ggml_tensor * last_conv_states =
|
||||
ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
|
||||
conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
|
||||
cb(last_conv_states, "last_conv_states", il);
|
||||
|
||||
ggml_tensor * state_update_target =
|
||||
ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
|
||||
kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
|
||||
cb(state_update_target, "state_update_target", il);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
|
||||
ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il);
|
||||
|
||||
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
|
||||
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
|
||||
@@ -413,7 +435,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear(
|
||||
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
|
||||
// if head keys and value keys are different, repeat to force tensors into matching shapes
|
||||
// note: need explicit repeat only if we are not using the fused GDN
|
||||
// note: need explicit repeat only if we are not using the fused GDN.
|
||||
if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
|
||||
GGML_ASSERT(num_v_heads % num_k_heads == 0);
|
||||
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
@@ -424,18 +446,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear(
|
||||
cb(k_conv, "k_conv_predelta", il);
|
||||
cb(v_conv, "v_conv_predelta", il);
|
||||
|
||||
auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
|
||||
ggml_tensor * output = attn_out.first;
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
cb(output, "attn_output", il);
|
||||
cb(new_state, "new_state", il);
|
||||
|
||||
// Update the recurrent states
|
||||
ggml_build_forward_expand(gf,
|
||||
ggml_cpy(ctx0, new_state,
|
||||
ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
|
||||
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
|
||||
ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
|
||||
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
|
||||
ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
@@ -471,3 +482,146 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, cons
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 dense series
|
||||
llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params)
|
||||
: llm_graph_context(params) {
|
||||
GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35 MTP requires nextn_predict_layers > 0");
|
||||
GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35 MTP currently only supports a single MTP block");
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
// The MTP block lives at the source file's original layer index.
|
||||
const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
|
||||
const auto & layer = model.layers[il];
|
||||
|
||||
GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");
|
||||
GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm");
|
||||
GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm");
|
||||
|
||||
int sections[4];
|
||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd);
|
||||
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
ggml_set_input(inp->tokens);
|
||||
|
||||
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
|
||||
ggml_set_input(inp->embd);
|
||||
ggml_set_name(inp->embd, "mtp_h_input");
|
||||
|
||||
ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
|
||||
|
||||
ggml_tensor * h_input = inp->embd;
|
||||
ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens);
|
||||
cb(tok_embd, "mtp_tok_embd", il);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
auto * inp_attn = build_attn_inp_kv();
|
||||
|
||||
ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(h_norm, "mtp_hnorm", il);
|
||||
|
||||
ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(e_norm, "mtp_enorm", il);
|
||||
|
||||
ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0);
|
||||
cb(concat, "mtp_concat", il);
|
||||
|
||||
ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat);
|
||||
cb(cur, "mtp_eh_proj", il);
|
||||
|
||||
ggml_tensor * inpSA = cur;
|
||||
|
||||
cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur, "mtp_attn_norm", il);
|
||||
|
||||
ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s);
|
||||
cb(Qcur_full, "mtp_Qcur_full", il);
|
||||
|
||||
ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full,
|
||||
n_embd_head, n_head, n_tokens,
|
||||
ggml_element_size(Qcur_full) * n_embd_head * 2,
|
||||
ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
|
||||
0);
|
||||
Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "mtp_Qcur_normed", il);
|
||||
|
||||
ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full,
|
||||
n_embd_head, n_head, n_tokens,
|
||||
ggml_element_size(Qcur_full) * n_embd_head * 2,
|
||||
ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
|
||||
ggml_element_size(Qcur_full) * n_embd_head);
|
||||
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
|
||||
cb(gate, "mtp_gate", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(Kcur, "mtp_Kcur_normed", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
cb(Vcur, "mtp_Vcur", il);
|
||||
|
||||
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
const float kq_scale = hparams.f_attention_scale == 0.0f
|
||||
? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
nullptr, nullptr, nullptr,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "mtp_attn_pregate", il);
|
||||
|
||||
cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
|
||||
cur = build_lora_mm(layer.wo, cur, layer.wo_s);
|
||||
cb(cur, "mtp_attn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, inpSA);
|
||||
cb(cur, "mtp_attn_residual", il);
|
||||
|
||||
ggml_tensor * ffn_residual = cur;
|
||||
cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur, "mtp_attn_post_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
layer.ffn_up, nullptr, layer.ffn_up_s,
|
||||
layer.ffn_gate, nullptr, layer.ffn_gate_s,
|
||||
layer.ffn_down, nullptr, layer.ffn_down_s,
|
||||
nullptr,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "mtp_ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_residual);
|
||||
cb(cur, "mtp_post_ffn", il);
|
||||
|
||||
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
|
||||
// (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.)
|
||||
cb(cur, "h_pre_norm", -1);
|
||||
res->t_h_pre_norm = cur;
|
||||
|
||||
ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
|
||||
? layer.nextn.shared_head_norm
|
||||
: model.output_norm;
|
||||
GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm");
|
||||
cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1);
|
||||
cb(cur, "mtp_shared_head_norm", -1);
|
||||
|
||||
ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
|
||||
GGML_ASSERT(head_w && "QWEN35 MTP: missing LM head (nextn.shared_head_head or model.output)");
|
||||
cur = build_lora_mm(head_w, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
res->t_logits = cur;
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
+279
-83
@@ -15,16 +15,22 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
|
||||
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
|
||||
|
||||
// Mark recurrent layers (linear attention layers)
|
||||
// NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
|
||||
|
||||
// Mark recurrent layers (linear attention layers). MTP layers are dense
|
||||
// attention-only and must be flagged non-recurrent.
|
||||
{
|
||||
const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
uint32_t full_attn_interval = 4;
|
||||
ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false);
|
||||
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
|
||||
hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0);
|
||||
hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0);
|
||||
}
|
||||
}
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
switch (hparams.n_layer - hparams.nextn_predict_layers) {
|
||||
case 40: type = LLM_TYPE_35B_A3B; break;
|
||||
case 48: type = LLM_TYPE_122B_A10B; break;
|
||||
case 60: type = LLM_TYPE_397B_A17B; break;
|
||||
@@ -32,9 +38,14 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) {
|
||||
}
|
||||
}
|
||||
|
||||
void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) {
|
||||
void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) {
|
||||
LLAMA_LOAD_LOCALS;
|
||||
|
||||
const uint32_t n_main = n_layer - hparams.nextn_predict_layers;
|
||||
const bool mtp_only = (hparams.nextn_predict_layers > 0) &&
|
||||
(ml.get_weight("blk.0.attn_norm.weight") == nullptr);
|
||||
const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
||||
|
||||
// output
|
||||
@@ -46,60 +57,105 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||
auto load_block_trunk = [&](int il, int flags) {
|
||||
auto & layer = layers[il];
|
||||
|
||||
// Calculate dimensions from hyperparameters
|
||||
const int64_t head_k_dim = hparams.ssm_d_state;
|
||||
const int64_t head_v_dim = hparams.ssm_d_state;
|
||||
const int64_t n_k_heads = hparams.ssm_n_group;
|
||||
const int64_t n_v_heads = hparams.ssm_dt_rank;
|
||||
const int64_t key_dim = head_k_dim * n_k_heads;
|
||||
const int64_t value_dim = head_v_dim * n_v_heads;
|
||||
const int64_t conv_dim = key_dim * 2 + value_dim;
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||
const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
// Calculate dimensions from hyperparameters
|
||||
const int64_t head_k_dim = hparams.ssm_d_state;
|
||||
const int64_t head_v_dim = hparams.ssm_d_state;
|
||||
const int64_t n_k_heads = hparams.ssm_n_group;
|
||||
const int64_t n_v_heads = hparams.ssm_dt_rank;
|
||||
const int64_t key_dim = head_k_dim * n_k_heads;
|
||||
const int64_t value_dim = head_v_dim * n_v_heads;
|
||||
const int64_t conv_dim = key_dim * 2 + value_dim;
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags);
|
||||
|
||||
if (!hparams.is_recurrent(i)) {
|
||||
if (!hparams.is_recurrent(il)) {
|
||||
// Attention layers
|
||||
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
|
||||
create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags);
|
||||
|
||||
// Q/K normalization for attention layers
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags);
|
||||
} else {
|
||||
// Linear attention (gated delta net) specific tensors
|
||||
// Create tensors with calculated dimensions
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
|
||||
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
|
||||
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
|
||||
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
|
||||
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0);
|
||||
layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0);
|
||||
layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0);
|
||||
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
|
||||
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
|
||||
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
|
||||
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags);
|
||||
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags);
|
||||
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags);
|
||||
layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags);
|
||||
layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags);
|
||||
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags);
|
||||
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags);
|
||||
}
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
|
||||
create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0);
|
||||
// Routed experts
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, flags);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, flags);
|
||||
create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, flags);
|
||||
|
||||
// Shared experts
|
||||
layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, flags);
|
||||
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags);
|
||||
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags);
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, flags);
|
||||
};
|
||||
|
||||
auto load_block_mtp = [&](int il) {
|
||||
auto & layer = layers[il];
|
||||
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||
const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
|
||||
|
||||
layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
|
||||
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
|
||||
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0);
|
||||
// MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN.
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0);
|
||||
|
||||
create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0);
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0);
|
||||
|
||||
// Routed experts
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, 0);
|
||||
create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, 0);
|
||||
|
||||
// Shared experts
|
||||
layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, 0);
|
||||
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0);
|
||||
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0);
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, 0);
|
||||
|
||||
// NextN-specific tensors that define the MTP block.
|
||||
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0);
|
||||
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0);
|
||||
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0);
|
||||
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED);
|
||||
};
|
||||
|
||||
for (int i = 0; i < (int) n_main; ++i) {
|
||||
load_block_trunk(i, trunk_flags);
|
||||
}
|
||||
for (int i = (int) n_main; i < n_layer; ++i) {
|
||||
load_block_mtp(i);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<llm_graph_context> llama_model_qwen35moe::build_arch_graph(const llm_graph_params & params) const {
|
||||
if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) {
|
||||
return std::make_unique<graph_mtp>(*this, params);
|
||||
}
|
||||
return std::make_unique<graph>(*this, params);
|
||||
}
|
||||
|
||||
@@ -124,7 +180,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
|
||||
const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers;
|
||||
for (int il = 0; il < n_transformer_layers; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
|
||||
@@ -141,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
|
||||
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
@@ -173,6 +231,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
|
||||
}
|
||||
cur = inpL;
|
||||
|
||||
cb(cur, "h_pre_norm", -1);
|
||||
res->t_h_pre_norm = cur;
|
||||
|
||||
// Final norm
|
||||
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
|
||||
@@ -310,8 +371,6 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear(
|
||||
const int64_t head_v_dim = d_inner / num_v_heads;
|
||||
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
||||
const auto kv_head = mctx_cur->get_head();
|
||||
|
||||
GGML_ASSERT(n_seqs != 0);
|
||||
GGML_ASSERT(ubatch.equal_seqs());
|
||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||
@@ -341,41 +400,14 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear(
|
||||
|
||||
gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
|
||||
|
||||
// Get convolution states from cache
|
||||
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
||||
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
||||
|
||||
// Build the convolution states tensor
|
||||
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
|
||||
cb(conv_states, "conv_states", il);
|
||||
|
||||
// Calculate convolution kernel size
|
||||
ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
|
||||
const int64_t conv_kernel_size = conv_kernel->ne[0];
|
||||
const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
|
||||
|
||||
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
|
||||
cb(conv_states, "conv_states_reshaped", il);
|
||||
|
||||
qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
|
||||
cb(qkv_mixed, "qkv_mixed_transposed", il);
|
||||
|
||||
ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
|
||||
cb(conv_input, "conv_input", il);
|
||||
|
||||
// Update convolution state cache
|
||||
// Extract the last (conv_kernel_size - 1) states from conv_input
|
||||
ggml_tensor * last_conv_states =
|
||||
ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
|
||||
conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
|
||||
cb(last_conv_states, "last_conv_states", il);
|
||||
|
||||
ggml_tensor * state_update_target =
|
||||
ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
|
||||
kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
|
||||
cb(state_update_target, "state_update_target", il);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
|
||||
ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il);
|
||||
|
||||
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
|
||||
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
|
||||
@@ -426,7 +458,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear(
|
||||
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
|
||||
// if head keys and value keys are different, repeat to force tensors into matching shapes
|
||||
// note: need explicit repeat only if we are not using the fused GDN
|
||||
// note: need explicit repeat only if we are not using the fused GDN.
|
||||
if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
|
||||
GGML_ASSERT(num_v_heads % num_k_heads == 0);
|
||||
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
@@ -437,18 +469,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear(
|
||||
cb(k_conv, "k_conv_predelta", il);
|
||||
cb(v_conv, "v_conv_predelta", il);
|
||||
|
||||
auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
|
||||
ggml_tensor * output = attn_out.first;
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
cb(output, "attn_output", il);
|
||||
cb(new_state, "new_state", il);
|
||||
|
||||
// Update the recurrent states
|
||||
ggml_build_forward_expand(gf,
|
||||
ggml_cpy(ctx0, new_state,
|
||||
ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
|
||||
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
|
||||
ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
|
||||
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
|
||||
ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
@@ -525,3 +546,178 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, c
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 MoE
|
||||
// Builds the compute graph for the single MTP (NextN) draft block.
//
// Inputs registered on the graph result:
//   - inp->tokens : I32 tensor with n_tokens entries
//   - inp->embd   : F32 tensor of size hparams.n_embd x n_tokens, named "mtp_h_input"
//
// Outputs stored on `res`:
//   - t_h_pre_norm : hidden state after the FFN residual, before the head norm
//   - t_logits     : result of the LM head projection
//
// The block is: RMS-norm both inputs -> concat -> eh_proj -> gated attention
// -> MoE FFN (+ optional gated shared expert) -> head norm -> LM head.
llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params)
|
||||
: llm_graph_context(params) {
|
||||
GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE MTP requires nextn_predict_layers > 0");
|
||||
GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE MTP currently only supports a single MTP block");
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
// index of the MTP block: the first layer after the (n_layer - nextn_predict_layers) leading layers
const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
|
||||
const auto & layer = model.layers[il];
|
||||
|
||||
// tensors that are dereferenced unconditionally below must be present
GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");
|
||||
GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm");
|
||||
GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm");
|
||||
GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp");
|
||||
|
||||
// local copy of the first 4 rope sections, passed to ggml_rope_multi below
int sections[4];
|
||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd);
|
||||
|
||||
// this graph consumes BOTH token ids and an embedding-sized float input
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
ggml_set_input(inp->tokens);
|
||||
|
||||
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
|
||||
ggml_set_input(inp->embd);
|
||||
ggml_set_name(inp->embd, "mtp_h_input");
|
||||
|
||||
// use the block's own token embedding table when it was loaded, else the model's
ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
|
||||
|
||||
ggml_tensor * h_input = inp->embd;
|
||||
ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens);
|
||||
cb(tok_embd, "mtp_tok_embd", il);
|
||||
|
||||
// inp is moved from here on; only the raw tensor pointers captured above are used
res->add_input(std::move(inp));
|
||||
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
auto * inp_attn = build_attn_inp_kv();
|
||||
|
||||
ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(h_norm, "mtp_hnorm", il);
|
||||
|
||||
ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(e_norm, "mtp_enorm", il);
|
||||
|
||||
// concatenation order along dim 0 is [e_norm, h_norm]; eh_proj consumes the result
ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0);
|
||||
cb(concat, "mtp_concat", il);
|
||||
|
||||
ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat);
|
||||
cb(cur, "mtp_eh_proj", il);
|
||||
|
||||
// residual input for the attention sub-block
ggml_tensor * inpSA = cur;
|
||||
|
||||
cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur, "mtp_attn_norm", il);
|
||||
|
||||
// the Q projection output is viewed twice below: once as the query, once as the gate
ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s);
|
||||
cb(Qcur_full, "mtp_Qcur_full", il);
|
||||
|
||||
// query view: per-head stride is 2 * n_embd_head elements, starting at offset 0
ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full,
|
||||
n_embd_head, n_head, n_tokens,
|
||||
ggml_element_size(Qcur_full) * n_embd_head * 2,
|
||||
ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
|
||||
0);
|
||||
Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "mtp_Qcur_normed", il);
|
||||
|
||||
// gate view: same strides as the query view, offset by n_embd_head elements
ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full,
|
||||
n_embd_head, n_head, n_tokens,
|
||||
ggml_element_size(Qcur_full) * n_embd_head * 2,
|
||||
ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
|
||||
ggml_element_size(Qcur_full) * n_embd_head);
|
||||
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
|
||||
cb(gate, "mtp_gate", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(Kcur, "mtp_Kcur_normed", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
cb(Vcur, "mtp_Vcur", il);
|
||||
|
||||
// rotary embedding is applied to Q and K after their RMS norms
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
// f_attention_scale == 0 selects the default 1/sqrt(head_dim) scale
const float kq_scale = hparams.f_attention_scale == 0.0f
|
||||
? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||
|
||||
// output projection is passed as nullptr here; wo is applied after gating below
cur = build_attn(inp_attn,
|
||||
nullptr, nullptr, nullptr,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "mtp_attn_pregate", il);
|
||||
|
||||
// sigmoid-gate the attention output, then project with wo
cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
|
||||
cur = build_lora_mm(layer.wo, cur, layer.wo_s);
|
||||
cb(cur, "mtp_attn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, inpSA);
|
||||
cb(cur, "mtp_attn_residual", il);
|
||||
|
||||
// residual input for the FFN sub-block
ggml_tensor * ffn_residual = cur;
|
||||
cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(cur, "mtp_attn_post_norm", il);
|
||||
|
||||
// MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe).
|
||||
ggml_tensor * moe_out =
|
||||
build_moe_ffn(cur,
|
||||
layer.ffn_gate_inp,
|
||||
layer.ffn_up_exps,
|
||||
layer.ffn_gate_exps,
|
||||
layer.ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
hparams.expert_weights_scale,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
|
||||
nullptr, layer.ffn_gate_up_exps,
|
||||
layer.ffn_up_exps_s,
|
||||
layer.ffn_gate_exps_s,
|
||||
layer.ffn_down_exps_s);
|
||||
cb(moe_out, "mtp_ffn_moe_out", il);
|
||||
|
||||
// shared expert is optional: only built when its up-projection is present
if (layer.ffn_up_shexp != nullptr) {
|
||||
ggml_tensor * ffn_shexp =
|
||||
build_ffn(cur,
|
||||
layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s,
|
||||
layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s,
|
||||
layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s,
|
||||
nullptr,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(ffn_shexp, "mtp_ffn_shexp", il);
|
||||
|
||||
// NOTE(review): ffn_gate_inp_shexp is used without a null check — confirm it is always loaded together with ffn_up_shexp
ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur);
|
||||
shared_gate = ggml_sigmoid(ctx0, shared_gate);
|
||||
cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il);
|
||||
|
||||
ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
|
||||
cb(ffn_shexp, "mtp_ffn_shexp_gated", il);
|
||||
|
||||
cur = ggml_add(ctx0, moe_out, ffn_shexp);
|
||||
} else {
|
||||
cur = moe_out;
|
||||
}
|
||||
cb(cur, "mtp_ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_residual);
|
||||
cb(cur, "mtp_post_ffn", il);
|
||||
|
||||
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
|
||||
cb(cur, "h_pre_norm", -1);
|
||||
res->t_h_pre_norm = cur;
|
||||
|
||||
// head norm: prefer the block's own shared_head_norm, else the model's output_norm
ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
|
||||
? layer.nextn.shared_head_norm
|
||||
: model.output_norm;
|
||||
GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm");
|
||||
cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1);
|
||||
cb(cur, "mtp_shared_head_norm", -1);
|
||||
|
||||
// LM head: prefer the block's own shared_head_head, else the model's output matrix
ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
|
||||
GGML_ASSERT(head_w && "QWEN35MOE MTP: missing LM head (nextn.shared_head_head or model.output)");
|
||||
cur = build_lora_mm(head_w, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
res->t_logits = cur;
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
@@ -378,8 +378,6 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear(
|
||||
const int64_t head_v_dim = d_inner / num_v_heads;
|
||||
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
||||
const auto kv_head = mctx_cur->get_head();
|
||||
|
||||
GGML_ASSERT(n_seqs != 0);
|
||||
GGML_ASSERT(ubatch.equal_seqs());
|
||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||
@@ -429,41 +427,14 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear(
|
||||
beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
|
||||
gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
|
||||
|
||||
// Get convolution states from cache
|
||||
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
||||
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
||||
|
||||
// Build the convolution states tensor
|
||||
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
|
||||
cb(conv_states, "conv_states", il);
|
||||
|
||||
// Calculate convolution kernel size
|
||||
ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
|
||||
const int64_t conv_kernel_size = conv_kernel->ne[0];
|
||||
const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
|
||||
|
||||
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
|
||||
cb(conv_states, "conv_states_reshaped", il);
|
||||
|
||||
qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
|
||||
cb(qkv_mixed, "qkv_mixed_transposed", il);
|
||||
|
||||
ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
|
||||
cb(conv_input, "conv_input", il);
|
||||
|
||||
// Update convolution state cache
|
||||
// Extract the last (conv_kernel_size - 1) states from conv_input
|
||||
ggml_tensor * last_conv_states =
|
||||
ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
|
||||
conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
|
||||
cb(last_conv_states, "last_conv_states", il);
|
||||
|
||||
ggml_tensor * state_update_target =
|
||||
ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
|
||||
kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
|
||||
cb(state_update_target, "state_update_target", il);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
|
||||
ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il);
|
||||
|
||||
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
|
||||
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
|
||||
@@ -540,18 +511,7 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear(
|
||||
cb(k_conv, "k_conv_predelta", il);
|
||||
cb(v_conv, "v_conv_predelta", il);
|
||||
|
||||
auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
|
||||
ggml_tensor * output = attn_out.first;
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
cb(output, "attn_output", il);
|
||||
cb(new_state, "new_state", il);
|
||||
|
||||
// Update the recurrent states
|
||||
ggml_build_forward_expand(gf,
|
||||
ggml_cpy(ctx0, new_state,
|
||||
ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
|
||||
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
|
||||
ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
|
||||
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
|
||||
ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
|
||||
@@ -252,6 +252,9 @@ llama_build_and_test(test-backend-sampler.cpp LABEL "model")
|
||||
llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -m "${MODEL_DEST}")
|
||||
set_tests_properties(test-state-restore-fragmented PROPERTIES FIXTURES_REQUIRED test-download-model)
|
||||
|
||||
llama_build_and_test(test-recurrent-state-rollback.cpp LABEL "model" ARGS -m "${MODEL_DEST}")
|
||||
set_tests_properties(test-recurrent-state-rollback PROPERTIES FIXTURES_REQUIRED test-download-model)
|
||||
|
||||
if (NOT GGML_BACKEND_DL)
|
||||
# these tests use the backends directly and cannot be built with dynamic loading
|
||||
llama_build_and_test(test-barrier.cpp)
|
||||
|
||||
@@ -3832,16 +3832,17 @@ struct test_gated_delta_net : public test_case {
|
||||
const int v_repeat;
|
||||
const bool permuted;
|
||||
const bool kda;
|
||||
const int64_t K; // snapshot slot count: 1 = final-only, >1 = last K states
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR8(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted, kda);
|
||||
return VARS_TO_STR9(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted, kda, K);
|
||||
}
|
||||
|
||||
test_gated_delta_net(ggml_type type = GGML_TYPE_F32,
|
||||
int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1,
|
||||
int v_repeat = 1, bool permuted = false, bool kda = false)
|
||||
int v_repeat = 1, bool permuted = false, bool kda = false, int64_t K = 1)
|
||||
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs),
|
||||
v_repeat(v_repeat), permuted(permuted), kda(kda) {}
|
||||
v_repeat(v_repeat), permuted(permuted), kda(kda), K(K) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * q;
|
||||
@@ -3863,7 +3864,7 @@ struct test_gated_delta_net : public test_case {
|
||||
const int64_t g_ne0 = kda ? head_size : 1;
|
||||
ggml_tensor * g = ggml_new_tensor_4d(ctx, type, g_ne0, head_count * v_repeat, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, head_count * v_repeat, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * v_repeat * head_size * head_count, n_seqs);
|
||||
ggml_tensor * state = ggml_new_tensor_3d(ctx, type, head_size * v_repeat * head_size * head_count, K, n_seqs);
|
||||
ggml_set_name(g, "g");
|
||||
ggml_set_name(beta, "beta");
|
||||
ggml_set_name(state, "state");
|
||||
@@ -9034,6 +9035,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 33, 1, 1, false, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 100, 1, 1, false, true));
|
||||
|
||||
// K > 1: output keeps the last min(n_tokens, K) per-token snapshots in the trailing K-token region.
|
||||
// exact-match cases (K == n_seq_tokens):
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 2, 1, 1, false, false, /*K=*/2));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, false, /*K=*/4));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, false, /*K=*/4));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 128, 4, 1, 1, false, false, /*K=*/4));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true, /*K=*/4));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true, /*K=*/4));
|
||||
// overflow: n_tokens > K — only the last K snapshots kept.
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 8, 1, 1, false, false, /*K=*/3));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 16, 2, 1, false, false, /*K=*/4));
|
||||
|
||||
#if 0
|
||||
// these tests are disabled to save execution time, but they can be handy for debugging
|
||||
test_cases.emplace_back(new test_llama(2, true));
|
||||
|
||||
@@ -0,0 +1,185 @@
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <clocale>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <vector>
|
||||
|
||||
// Create a single-sequence context on `model` with n_rs_seq forced to 8.
// The batch and micro-batch sizes are raised, if needed, to at least n_rs_seq + 1.
// NOTE(review): n_rs_seq presumably sets how many recurrent-state rollback slots
// are kept per sequence — confirm against the llama_context_params docs.
static llama_context * make_ctx(const common_params & params, llama_model * model) {
    auto ctx_params = common_context_params_to_llama(params);

    ctx_params.n_seq_max = 1;
    ctx_params.n_rs_seq  = 8;

    const uint32_t min_batch = (uint32_t) (ctx_params.n_rs_seq + 1);
    if (ctx_params.n_batch < min_batch) {
        ctx_params.n_batch = min_batch;
    }
    if (ctx_params.n_ubatch < min_batch) {
        ctx_params.n_ubatch = min_batch;
    }

    return llama_init_from_model(model, ctx_params);
}
|
||||
|
||||
// Decode the first `count` entries of `tokens` as one batch on sequence 0,
// at positions 0..count-1, without requesting logits for any of them.
// Returns true when llama_decode reports success (0).
static bool decode_tokens(llama_context * ctx, const std::vector<llama_token> & tokens, uint32_t count) {
    llama_batch prompt = llama_batch_init(count, 0, 1);

    uint32_t idx = 0;
    while (idx < count) {
        common_batch_add(prompt, tokens[idx], idx, { 0 }, false);
        ++idx;
    }

    const auto status = llama_decode(ctx, prompt);
    llama_batch_free(prompt);

    return status == 0;
}
|
||||
|
||||
// Decode a single token at position `pos` on sequence 0 with logits requested.
// Returns true when llama_decode reports success (0).
static bool decode_one(llama_context * ctx, llama_token tok, llama_pos pos) {
    llama_batch single = llama_batch_init(1, 0, 1);
    common_batch_add(single, tok, pos, { 0 }, true);

    const auto status = llama_decode(ctx, single);
    llama_batch_free(single);

    return status == 0;
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
common_params params;
|
||||
params.sampling.seed = 1234;
|
||||
params.n_predict = 1;
|
||||
|
||||
common_init();
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ggml_backend_load_all();
|
||||
|
||||
common_init_result_ptr llama_init = common_init_from_params(params);
|
||||
llama_model * model = llama_init->model();
|
||||
if (model == nullptr) {
|
||||
fprintf(stderr, "%s : failed to init model\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!llama_model_is_recurrent(model) && !llama_model_is_hybrid(model)) {
|
||||
fprintf(stderr, "%s : skipping for non-recurrent model\n", __func__);
|
||||
return 0;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
|
||||
llama_context * ctx_src = make_ctx(params, model);
|
||||
llama_context * ctx_dst = make_ctx(params, model);
|
||||
if (ctx_src == nullptr || ctx_dst == nullptr) {
|
||||
fprintf(stderr, "%s : failed to init contexts\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (llama_n_rs_seq(ctx_src) == 0) {
|
||||
fprintf(stderr, "%s : skipping because n_rs_seq is disabled\n", __func__);
|
||||
llama_free(ctx_src);
|
||||
llama_free(ctx_dst);
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokens = common_tokenize(ctx_src, "The quick brown fox jumps", true);
|
||||
const uint32_t n_rs_seq = llama_n_rs_seq(ctx_src);
|
||||
if (tokens.size() > n_rs_seq + 1) {
|
||||
tokens.resize(n_rs_seq + 1);
|
||||
}
|
||||
if (tokens.size() < 2) {
|
||||
fprintf(stderr, "%s : not enough prompt tokens\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
const uint32_t n_tokens = tokens.size();
|
||||
const llama_token last_tok = tokens.back();
|
||||
const llama_pos last_pos = (llama_pos) n_tokens - 2;
|
||||
|
||||
// Decode the full prompt on the source, then roll back the last position.
|
||||
// Rollback leaves the recurrent memory in a snapshot state (rs_idx != 0).
|
||||
if (!decode_tokens(ctx_src, tokens, n_tokens)) {
|
||||
fprintf(stderr, "%s : failed to decode prompt\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
if (!llama_memory_seq_rm(llama_get_memory(ctx_src), 0, last_pos, -1)) {
|
||||
fprintf(stderr, "%s : rollback failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Save the rolled-back state and restore it into a fresh context.
|
||||
common_prompt_checkpoint ckpt;
|
||||
ckpt.update_tgt(ctx_src, 0, 0);
|
||||
ckpt.load_tgt(ctx_dst, 0, 0);
|
||||
|
||||
// Replay the rolled-back token on both contexts and compare logits.
|
||||
if (!decode_one(ctx_src, last_tok, last_pos) ||
|
||||
!decode_one(ctx_dst, last_tok, last_pos)) {
|
||||
fprintf(stderr, "%s : replay failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const float * logits_src = llama_get_logits_ith(ctx_src, 0);
|
||||
const float * logits_dst = llama_get_logits_ith(ctx_dst, 0);
|
||||
if (logits_src == nullptr || logits_dst == nullptr) {
|
||||
fprintf(stderr, "%s : missing logits\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
constexpr float eps = 1e-5f;
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
if (std::fabs(logits_src[i] - logits_dst[i]) > eps) {
|
||||
fprintf(stderr, "%s : logits mismatch at token %d (%g != %g)\n",
|
||||
__func__, i, (double) logits_src[i], (double) logits_dst[i]);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Repeat the load into a context that already has its own rollback state:
|
||||
// groups 1..n_rs_seq hold a *different* prompt's history, and rs_idx[0] is
|
||||
// non-zero at load time. The restore must wipe that state and still match.
|
||||
llama_context * ctx_dirty = make_ctx(params, model);
|
||||
if (ctx_dirty == nullptr) {
|
||||
fprintf(stderr, "%s : failed to init dirty ctx\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::vector<llama_token> noise = tokens;
|
||||
for (auto & t : noise) {
|
||||
t = (t + 1) % n_vocab;
|
||||
if (t < 0) {
|
||||
t = 0;
|
||||
}
|
||||
}
|
||||
if (!decode_tokens(ctx_dirty, noise, n_tokens)) {
|
||||
fprintf(stderr, "%s : dirty prompt decode failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
if (!llama_memory_seq_rm(llama_get_memory(ctx_dirty), 0, last_pos, -1)) {
|
||||
fprintf(stderr, "%s : dirty rollback failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
ckpt.load_tgt(ctx_dirty, 0, 0);
|
||||
|
||||
if (!decode_one(ctx_dirty, last_tok, last_pos)) {
|
||||
fprintf(stderr, "%s : dirty replay failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const float * logits_dirty = llama_get_logits_ith(ctx_dirty, 0);
|
||||
if (logits_dirty == nullptr) {
|
||||
fprintf(stderr, "%s : missing dirty logits\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
if (std::fabs(logits_src[i] - logits_dirty[i]) > eps) {
|
||||
fprintf(stderr, "%s : dirty-ctx logits mismatch at token %d (%g != %g)\n",
|
||||
__func__, i, (double) logits_src[i], (double) logits_dirty[i]);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : recurrent rollback checkpoint restored successfully\n", __func__);
|
||||
llama_free(ctx_src);
|
||||
llama_free(ctx_dst);
|
||||
llama_free(ctx_dirty);
|
||||
return 0;
|
||||
}
|
||||
+3
-4
@@ -55,7 +55,6 @@
|
||||
| `-ctv, --cache-type-v TYPE` | KV cache data type for V<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V) |
|
||||
| `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
|
||||
| `-np, --parallel N` | number of parallel sequences to decode (default: 1)<br/>(env: LLAMA_ARG_N_PARALLEL) |
|
||||
| `--rpc SERVERS` | comma-separated list of RPC servers (host:port)<br/>(env: LLAMA_ARG_RPC) |
|
||||
| `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
|
||||
| `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)<br/>(env: LLAMA_ARG_MMAP) |
|
||||
| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)<br/>(env: LLAMA_ARG_DIO) |
|
||||
@@ -94,8 +93,8 @@
|
||||
| `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) |
|
||||
| `--offline` | Offline mode: forces use of cache, prevents network access<br/>(env: LLAMA_OFFLINE) |
|
||||
| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:<br/> - 0: generic output<br/> - 1: error<br/> - 2: warning<br/> - 3: info<br/> - 4: debug<br/>(default: 3)<br/><br/>(env: LLAMA_LOG_VERBOSITY) |
|
||||
| `--log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_LOG_PREFIX) |
|
||||
| `--log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_LOG_TIMESTAMPS) |
|
||||
| `--log-prefix, --no-log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_ARG_LOG_PREFIX) |
|
||||
| `--log-timestamps, --no-log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_ARG_LOG_TIMESTAMPS) |
|
||||
| `--spec-draft-type-k, -ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_K) |
|
||||
| `--spec-draft-type-v, -ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_V) |
|
||||
|
||||
@@ -199,7 +198,7 @@
|
||||
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
|
||||
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
|
||||
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
|
||||
| `--spec-type none,draft-simple,draft-eagle3,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
|
||||
| `--spec-type none,draft-simple,draft-eagle3,draft-mtp,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
|
||||
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
|
||||
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |
|
||||
| `--spec-ngram-mod-n-match N` | ngram-mod lookup length (default: 24) |
|
||||
|
||||
@@ -138,7 +138,6 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
|
||||
| `-ctv, --cache-type-v TYPE` | KV cache data type for V<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V) |
|
||||
| `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
|
||||
| `-np, --parallel N` | number of parallel sequences to decode (default: 1)<br/>(env: LLAMA_ARG_N_PARALLEL) |
|
||||
| `--rpc SERVERS` | comma-separated list of RPC servers (host:port)<br/>(env: LLAMA_ARG_RPC) |
|
||||
| `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
|
||||
| `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)<br/>(env: LLAMA_ARG_MMAP) |
|
||||
| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)<br/>(env: LLAMA_ARG_DIO) |
|
||||
@@ -177,8 +176,8 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
|
||||
| `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) |
|
||||
| `--offline` | Offline mode: forces use of cache, prevents network access<br/>(env: LLAMA_OFFLINE) |
|
||||
| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:<br/> - 0: generic output<br/> - 1: error<br/> - 2: warning<br/> - 3: info<br/> - 4: debug<br/>(default: 3)<br/><br/>(env: LLAMA_LOG_VERBOSITY) |
|
||||
| `--log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_LOG_PREFIX) |
|
||||
| `--log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_LOG_TIMESTAMPS) |
|
||||
| `--log-prefix, --no-log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_ARG_LOG_PREFIX) |
|
||||
| `--log-timestamps, --no-log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_ARG_LOG_TIMESTAMPS) |
|
||||
| `--spec-draft-type-k, -ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_K) |
|
||||
| `--spec-draft-type-v, -ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_V) |
|
||||
|
||||
|
||||
+11
-8
@@ -72,7 +72,6 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `-ctk, --cache-type-k TYPE` | KV cache data type for K<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K) |
|
||||
| `-ctv, --cache-type-v TYPE` | KV cache data type for V<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V) |
|
||||
| `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
|
||||
| `--rpc SERVERS` | comma-separated list of RPC servers (host:port)<br/>(env: LLAMA_ARG_RPC) |
|
||||
| `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
|
||||
| `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)<br/>(env: LLAMA_ARG_MMAP) |
|
||||
| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)<br/>(env: LLAMA_ARG_DIO) |
|
||||
@@ -111,8 +110,8 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) |
|
||||
| `--offline` | Offline mode: forces use of cache, prevents network access<br/>(env: LLAMA_OFFLINE) |
|
||||
| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:<br/> - 0: generic output<br/> - 1: error<br/> - 2: warning<br/> - 3: info<br/> - 4: debug<br/>(default: 3)<br/><br/>(env: LLAMA_LOG_VERBOSITY) |
|
||||
| `--log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_LOG_PREFIX) |
|
||||
| `--log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_LOG_TIMESTAMPS) |
|
||||
| `--log-prefix, --no-log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_ARG_LOG_PREFIX) |
|
||||
| `--log-timestamps, --no-log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_ARG_LOG_TIMESTAMPS) |
|
||||
| `--spec-draft-type-k, -ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_K) |
|
||||
| `--spec-draft-type-v, -ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_V) |
|
||||
|
||||
@@ -189,11 +188,15 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `--reuse-port` | allow multiple sockets to bind to the same port (default: disabled)<br/>(env: LLAMA_ARG_REUSE_PORT) |
|
||||
| `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
|
||||
| `--api-prefix PREFIX` | prefix path the server serves from, without the trailing slash (default: )<br/>(env: LLAMA_ARG_API_PREFIX) |
|
||||
| `--ui-config JSON` / `--webui-config JSON` (deprecated) | JSON that provides default UI settings (overrides UI defaults)<br/>(env: LLAMA_ARG_UI_CONFIG / LLAMA_ARG_WEBUI_CONFIG) |
|
||||
| `--ui-config-file PATH` / `--webui-config-file PATH` (deprecated) | JSON file that provides default UI settings (overrides UI defaults)<br/>(env: LLAMA_ARG_UI_CONFIG_FILE / LLAMA_ARG_WEBUI_CONFIG_FILE) |
|
||||
| `--ui-mcp-proxy, --no-ui-mcp-proxy` / `--webui-mcp-proxy, --no-webui-mcp-proxy` (deprecated) | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)<br/>(env: LLAMA_ARG_UI_MCP_PROXY / LLAMA_ARG_WEBUI_MCP_PROXY) |
|
||||
| `--webui-config JSON` | [DEPRECATED: use --ui-config] JSON that provides default WebUI settings (overrides WebUI defaults)<br/>(env: LLAMA_ARG_WEBUI_CONFIG) |
|
||||
| `--ui-config JSON` | JSON that provides default UI settings (overrides UI defaults)<br/>(env: LLAMA_ARG_UI_CONFIG) |
|
||||
| `--webui-config-file PATH` | [DEPRECATED: use --ui-config-file] JSON file that provides default WebUI settings (overrides WebUI defaults)<br/>(env: LLAMA_ARG_WEBUI_CONFIG_FILE) |
|
||||
| `--ui-config-file PATH` | JSON file that provides default UI settings (overrides UI defaults)<br/>(env: LLAMA_ARG_UI_CONFIG_FILE) |
|
||||
| `--webui-mcp-proxy, --no-webui-mcp-proxy` | [DEPRECATED: use --ui-mcp-proxy/--no-ui-mcp-proxy] experimental: whether to enable MCP CORS proxy<br/>(env: LLAMA_ARG_WEBUI_MCP_PROXY) |
|
||||
| `--ui-mcp-proxy, --no-ui-mcp-proxy` | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)<br/>(env: LLAMA_ARG_UI_MCP_PROXY) |
|
||||
| `--tools TOOL1,TOOL2,...` | experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)<br/>specify "all" to enable all tools<br/>available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff, get_datetime<br/>(env: LLAMA_ARG_TOOLS) |
|
||||
| `--ui, --no-ui` / `--webui, --no-webui` (deprecated) | whether to enable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_UI / LLAMA_ARG_WEBUI) |
|
||||
| `--webui, --no-webui` | [DEPRECATED: use --ui/--no-ui] whether to enable the Web UI<br/>(env: LLAMA_ARG_WEBUI) |
|
||||
| `--ui, --no-ui` | whether to enable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_UI) |
|
||||
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
|
||||
| `--rerank, --reranking` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
|
||||
| `--api-key KEY` | API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)<br/>(env: LLAMA_API_KEY) |
|
||||
@@ -248,7 +251,7 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
|
||||
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
|
||||
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
|
||||
| `--spec-type none,draft-simple,draft-eagle3,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
|
||||
| `--spec-type none,draft-simple,draft-eagle3,draft-mtp,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
|
||||
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
|
||||
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |
|
||||
| `--spec-ngram-mod-n-match N` | ngram-mod lookup length (default: 24) |
|
||||
|
||||
@@ -145,9 +145,9 @@ struct server_slot {
|
||||
|
||||
SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_tgt), id, -1, -1);
|
||||
common_context_seq_rm(ctx_tgt, id, -1, -1);
|
||||
if (ctx_dft) {
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1);
|
||||
common_context_seq_rm(ctx_dft, id, -1, -1);
|
||||
}
|
||||
|
||||
prompt.tokens.clear();
|
||||
@@ -238,8 +238,14 @@ struct server_slot {
|
||||
(ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
|
||||
}
|
||||
|
||||
bool need_embd() const {
|
||||
GGML_ASSERT(task);
|
||||
return task->need_embd() || (spec && common_speculative_need_embd(spec));
|
||||
}
|
||||
|
||||
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
|
||||
// also we cannot split if the pooling would require any past tokens
|
||||
// (MTP supports splitting — uses task->need_embd() not need_embd())
|
||||
bool can_split() const {
|
||||
GGML_ASSERT(task);
|
||||
|
||||
@@ -511,12 +517,12 @@ struct server_slot {
|
||||
void copy_state_to(server_slot & other) const {
|
||||
GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_tgt), other.id, -1, -1);
|
||||
llama_memory_seq_cp(llama_get_memory(ctx_tgt), id, other.id, -1, -1);
|
||||
common_context_seq_rm(ctx_tgt, other.id, -1, -1);
|
||||
common_context_seq_cp(ctx_tgt, id, other.id, -1, -1);
|
||||
|
||||
if (ctx_dft) {
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft), other.id, -1, -1);
|
||||
llama_memory_seq_cp(llama_get_memory(ctx_dft), id, other.id, -1, -1);
|
||||
common_context_seq_rm(ctx_dft, other.id, -1, -1);
|
||||
common_context_seq_cp(ctx_dft, id, other.id, -1, -1);
|
||||
}
|
||||
|
||||
other.n_decoded = n_decoded;
|
||||
@@ -775,10 +781,40 @@ private:
|
||||
}
|
||||
|
||||
auto cparams = common_context_params_to_llama(params_dft);
|
||||
|
||||
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
|
||||
params_base.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
|
||||
if (spec_mtp) {
|
||||
cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
|
||||
}
|
||||
|
||||
// note: for small models maybe we can set this to the maximum possible draft from all speculative types
|
||||
// the extra memory for small models is likely negligible?
|
||||
cparams.n_rs_seq = 0;
|
||||
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
|
||||
|
||||
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
|
||||
|
||||
params_base.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
params_base.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
} else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end()) {
|
||||
SRV_INF("creating MTP draft context against the target model '%s'\n",
|
||||
params_base.model.path.c_str());
|
||||
|
||||
auto cparams_mtp = common_context_params_to_llama(params_base);
|
||||
cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
|
||||
cparams_mtp.n_rs_seq = 0;
|
||||
|
||||
ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp));
|
||||
if (ctx_dft == nullptr) {
|
||||
SRV_ERR("%s", "failed to create MTP context\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
|
||||
|
||||
params_base.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
params_base.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
}
|
||||
@@ -2194,12 +2230,12 @@ private:
|
||||
|
||||
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
|
||||
|
||||
llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, n_keep , n_keep + n_discard);
|
||||
llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
|
||||
common_context_seq_rm (ctx_tgt, slot.id, n_keep , n_keep + n_discard);
|
||||
common_context_seq_add(ctx_tgt, slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
|
||||
|
||||
if (ctx_dft) {
|
||||
llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, n_keep , n_keep + n_discard);
|
||||
llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard);
|
||||
common_context_seq_rm (ctx_dft.get(), slot.id, n_keep , n_keep + n_discard);
|
||||
common_context_seq_add(ctx_dft.get(), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard);
|
||||
}
|
||||
|
||||
// add generated tokens to cache
|
||||
@@ -2306,14 +2342,23 @@ private:
|
||||
slot.n_draft_total += draft.size();
|
||||
|
||||
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
|
||||
if (ctx_dft) {
|
||||
ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, ckpt.pos_max + 1, -1);
|
||||
if (ctx_dft) {
|
||||
if (use_ckpt_dft) {
|
||||
ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
}
|
||||
|
||||
common_context_seq_rm(ctx_dft.get(), slot.id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
|
||||
if (!draft.empty()) {
|
||||
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
const bool use_ckpt_tgt =
|
||||
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ||
|
||||
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && draft.size() > llama_n_rs_seq(ctx_tgt));
|
||||
|
||||
const bool use_ckpt_dft =
|
||||
(ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && draft.size() > llama_n_rs_seq(ctx_dft.get()));
|
||||
|
||||
if (use_ckpt_tgt) {
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
@@ -2328,6 +2373,10 @@ private:
|
||||
(float) ckpt.size() / 1024 / 1024,
|
||||
(float) ckpt.data_dft.size() / 1024 / 1024);
|
||||
}
|
||||
|
||||
if (use_ckpt_dft) {
|
||||
ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2499,12 +2548,12 @@ private:
|
||||
|
||||
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
|
||||
|
||||
llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, head_p, head_c);
|
||||
llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, head_c, head_c + n_match, kv_shift);
|
||||
common_context_seq_rm (ctx_tgt, slot.id, head_p, head_c);
|
||||
common_context_seq_add(ctx_tgt, slot.id, head_c, head_c + n_match, kv_shift);
|
||||
|
||||
if (ctx_dft) {
|
||||
llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, head_p, head_c);
|
||||
llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, head_c, head_c + n_match, kv_shift);
|
||||
common_context_seq_rm (ctx_dft.get(), slot.id, head_p, head_c);
|
||||
common_context_seq_add(ctx_dft.get(), slot.id, head_c, head_c + n_match, kv_shift);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < n_match; i++) {
|
||||
@@ -2667,18 +2716,10 @@ private:
|
||||
|
||||
SLT_TRC(slot, "cached n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
|
||||
|
||||
if (!llama_memory_seq_rm(llama_get_memory(ctx_tgt), slot.id, p0, -1)) {
|
||||
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
|
||||
|
||||
slot.prompt_clear(true);
|
||||
|
||||
// there is no common part left
|
||||
slot.n_prompt_tokens_cache = 0;
|
||||
} else {
|
||||
if (ctx_dft && !llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, p0, -1)) {
|
||||
GGML_ABORT("failed to truncate draft context\n");
|
||||
}
|
||||
}
|
||||
common_context_seq_rm(ctx_tgt, slot.id, p0, -1);
|
||||
if (ctx_dft) {
|
||||
common_context_seq_rm(ctx_dft.get(), slot.id, p0, -1);
|
||||
}
|
||||
|
||||
// If using an alora, there may be uncached tokens that come
|
||||
// before the invocation sequence. When this happens, the
|
||||
@@ -2703,9 +2744,11 @@ private:
|
||||
// checkpoints are created only if:
|
||||
// - the model does not support partial sequence removal
|
||||
// - the model uses SWA (and we are not using `swa_full`)
|
||||
// - the model supports partial sequence removal but only up to a fixed bound
|
||||
do_checkpoint = do_checkpoint && (
|
||||
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
|
||||
(n_swa > 0));
|
||||
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ||
|
||||
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS ||
|
||||
n_swa > 0);
|
||||
|
||||
bool has_mtmd = false;
|
||||
|
||||
@@ -2758,12 +2801,14 @@ private:
|
||||
break;
|
||||
}
|
||||
|
||||
// embedding requires all tokens in the batch to be output
|
||||
// embedding requires all tokens in the batch to be output;
|
||||
// MTP also wants logits at every prompt position so the
|
||||
// streaming hook can mirror t_h_pre_norm into ctx_dft.
|
||||
common_batch_add(batch,
|
||||
cur_tok,
|
||||
slot.prompt.tokens.pos_next(),
|
||||
{ slot.id },
|
||||
slot.task->need_embd());
|
||||
slot.need_embd());
|
||||
slot.prompt.tokens.push_back(cur_tok);
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
@@ -2877,7 +2922,7 @@ private:
|
||||
slot_batched->lora[alora_disabled_id].scale = alora_scale;
|
||||
}
|
||||
|
||||
llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd());
|
||||
llama_set_embeddings(ctx_tgt, slot_batched->need_embd());
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
@@ -3140,13 +3185,8 @@ private:
|
||||
|
||||
// verify and try to accept the draft
|
||||
{
|
||||
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
|
||||
// only save the sampler state if we use checkpoints
|
||||
common_sampler_ptr smpl_save;
|
||||
if (use_ckpt_tgt) {
|
||||
smpl_save.reset(common_sampler_clone(slot.smpl.get()));
|
||||
}
|
||||
// save the sampler state in case we need to restore it
|
||||
common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
|
||||
|
||||
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
|
||||
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft);
|
||||
@@ -3154,8 +3194,14 @@ private:
|
||||
|
||||
GGML_ASSERT(accepted.size() >= 1);
|
||||
|
||||
const uint32_t n_rollback = slot.spec_draft.size() + 1 - accepted.size();
|
||||
|
||||
const bool use_ckpt_tgt =
|
||||
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ||
|
||||
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt));
|
||||
|
||||
// check for partial draft acceptance
|
||||
if (accepted.size() < slot.spec_draft.size() + 1) {
|
||||
if (n_rollback > 0) {
|
||||
if (use_ckpt_tgt) {
|
||||
if (trace > 0) {
|
||||
SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size());
|
||||
@@ -3171,13 +3217,13 @@ private:
|
||||
{
|
||||
ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, ckpt.pos_max + 1, -1);
|
||||
common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
|
||||
if (slot.ctx_dft) {
|
||||
ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1);
|
||||
common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1);
|
||||
}
|
||||
|
||||
slot.prompt.tokens.keep_first(ckpt.n_tokens);
|
||||
@@ -3200,7 +3246,6 @@ private:
|
||||
|
||||
const auto ids = std::move(slot.spec_draft);
|
||||
|
||||
slot.n_decoded += ids.size();
|
||||
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
||||
|
||||
// update how many tokens out of those tested were accepted
|
||||
@@ -3213,9 +3258,9 @@ private:
|
||||
slot.sampled = ids.back(); // last accepted token
|
||||
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, slot.prompt.tokens.pos_next(), -1);
|
||||
common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1);
|
||||
if (slot.ctx_dft) {
|
||||
llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, slot.prompt.tokens.pos_next(), -1);
|
||||
common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < ids.size(); ++i) {
|
||||
@@ -3227,6 +3272,8 @@ private:
|
||||
|
||||
// TODO: set result.probs
|
||||
|
||||
slot.n_decoded += 1;
|
||||
|
||||
if (!process_token(result, slot)) {
|
||||
slot.print_timings();
|
||||
send_final_response(slot);
|
||||
|
||||
Reference in New Issue
Block a user