llama + spec: MTP Support (#22673)

* spec: support MTP

* fix batch size

* rename files

* cont : simplify (#7)

* MTP: clean-up (#9)

* MTP: clean-up

* review: use llama_context_type instead of llama_graph_type

* review: remove llama_model_has_mtp

* review: fix convert issues

* convert: fix pycheck

* review: formatting

* use `mtp-` for identifying mtp models

* convert: fix mtp conversion

* mtp -> draft-mtp

* remove unused llama_arch

* add need_embd in speculative

* llama: allow partial seq_rm for GDN models for speculative decoding

Currently speculative checkpoint needs to restart from a checkpoint
after some draft tokens are not accepted, this leads to some wastage in
running the target again. This PR adds the ability to rollback upto
`draft_max` by storing the GDN intermediates.

* fix pending state

* vulkan: add GDN partial rollback

* meta: extend check to axis 1

* metal: add GDN partial rollback

Extend the gated delta net kernel to store intermediate states for
partial rollback support on the Metal backend.

- Add K (snapshot slot count) as a function constant
- Read input state from slot 0 of the 3D state tensor
- Write intermediate states to different slots during token loop
- For K=1, maintain backward-compatible single-slot behavior

Ref: https://github.com/ggml-org/llama.cpp/commit/8c05923630110223669f069af2000e9cf10c02bc

Assisted-by: llama.cpp:local pi

* delta_net_base: use ggml_pad instead of new_tensor

* review: add need_rs_seq

* review: rename part_bounded to n_rs

* review: deslop comments

* review: rename, add asserts

* server : adjust checkpoint logic (#11)

* server : adjust checkpoint logic

* cont : rm asserts

* server-context: fix early exit

* spec : fix compatibility with n-gram and add TODOs (#13)

* metal : cleanup

* llama : fix faulty bitwise check in recurrent memory

* server : disable RS-based MTP in combination with other spec types

* spec : add TODOs

* cont : fix comment

* cont : update comment

* common : fix logic for ngram + mtp compat

* llama-memory: enable checkpointing with partial rollback

* cont: add test-case for loading into a dirty ctx

* llama-memory-recurrent: clear rs_idx in clear

* download: fix mtp path

* llama-arch: fix enorm op

* docs: update docs

* conversion: fix type annotations

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Aman Gupta
2026-05-16 20:06:23 +08:00
committed by GitHub
parent b81c2cdd74
commit 255582687b
54 changed files with 2227 additions and 413 deletions
+28 -6
View File
@@ -337,11 +337,15 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
struct handle_model_result {
bool found_mmproj = false;
common_params_model mmproj;
bool found_mtp = false;
common_params_model mtp;
};
static handle_model_result common_params_handle_model(struct common_params_model & model,
const std::string & bearer_token,
bool offline) {
bool offline,
bool search_mtp = false) {
handle_model_result result;
if (!model.docker_repo.empty()) {
@@ -356,7 +360,7 @@ static handle_model_result common_params_handle_model(struct common_params_model
common_download_opts opts;
opts.bearer_token = bearer_token;
opts.offline = offline;
auto download_result = common_download_model(model, opts, true);
auto download_result = common_download_model(model, opts, true, search_mtp);
if (download_result.model_path.empty()) {
throw std::runtime_error("failed to download model from Hugging Face");
@@ -369,6 +373,11 @@ static handle_model_result common_params_handle_model(struct common_params_model
result.found_mmproj = true;
result.mmproj.path = download_result.mmproj_path;
}
if (!download_result.mtp_path.empty()) {
result.found_mtp = true;
result.mtp.path = download_result.mtp_path;
}
} else if (!model.url.empty()) {
if (model.path.empty()) {
auto f = string_split<std::string>(model.url, '#').front();
@@ -436,7 +445,11 @@ static bool parse_bool_value(const std::string & value) {
//
void common_params_handle_models(common_params & params, llama_example curr_ex) {
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
params.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_draft_mtp);
if (params.no_mmproj) {
params.mmproj = {};
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
@@ -450,6 +463,14 @@ void common_params_handle_models(common_params & params, llama_example curr_ex)
break;
}
}
// when --spec-type mtp is set and no draft model was provided explicitly,
// fall back to the MTP head discovered alongside the -hf model
if (spec_type_draft_mtp && res.found_mtp &&
params.speculative.draft.mparams.path.empty() &&
params.speculative.draft.mparams.hf_repo.empty() &&
params.speculative.draft.mparams.url.empty()) {
params.speculative.draft.mparams.path = res.mtp.path;
}
common_params_handle_model(params.speculative.draft.mparams, params.hf_token, params.offline);
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
}
@@ -3608,8 +3629,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("comma-separated list of types of speculative decoding to use (default: %s)\n",
common_speculative_type_name_str(params.speculative.types).c_str()),
[](common_params & params, const std::string & value) {
const auto enabled_types = string_split<std::string>(value, ',');
params.speculative.types = common_speculative_types_from_names(enabled_types);
const auto types_str = string_split<std::string>(value, ',');
auto types = common_speculative_types_from_names(types_str);
params.speculative.types.insert(params.speculative.types.end(), types.begin(), types.end());
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_TYPE"));
add_opt(common_arg(
@@ -4098,7 +4120,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--spec-default"},
string_format("enable default speculative decoding config"),
[](common_params & params) {
params.speculative.types = { COMMON_SPECULATIVE_TYPE_NGRAM_MOD };
params.speculative.types.push_back(COMMON_SPECULATIVE_TYPE_NGRAM_MOD);
params.speculative.ngram_mod.n_match = 24;
params.speculative.ngram_mod.n_min = 48;
params.speculative.ngram_mod.n_max = 64;
+56
View File
@@ -7,6 +7,7 @@
#include "log.h"
#include "llama.h"
#include "sampling.h"
#include "speculative.h"
#include "unicode.h"
#include <algorithm>
@@ -1247,6 +1248,29 @@ common_init_result::common_init_result(common_params & params) :
cparams.n_samplers = pimpl->samplers_seq_config.size();
}
// [TAG_RS_STATE_ROLLBACK_SUPPORT]
// TODO: ngram speculative methods require checkpointing in addition to partial RS rollback
// currently this is not supported. so we disable the partial rollback
if (cparams.n_rs_seq > 0 && (llama_model_is_recurrent(model) || llama_model_is_hybrid(model))) {
auto & types = params.speculative.types;
for (int i = 0; i < (int) types.size(); i++) {
if (types[i] == COMMON_SPECULATIVE_TYPE_NONE) {
continue;
}
if (types[i] == COMMON_SPECULATIVE_TYPE_DRAFT_MTP) {
continue;
}
cparams.n_rs_seq = 0;
LOG_WRN("%s: recurrent state rollback is not compatible with '%s' - disabling rollback support\n", __func__,
common_speculative_type_to_str(types[i]).c_str());
break;
}
}
llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
@@ -1435,6 +1459,12 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
goto done;
}
if (llama_n_rs_seq(ctx) > 0) {
LOG_INF("%s: the context supports bounded partial sequence removal\n", __func__);
res = COMMON_CONTEXT_SEQ_RM_TYPE_RS;
goto done;
}
// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_TRC("%s: the context does not support partial sequence removal\n", __func__);
@@ -1449,6 +1479,23 @@ done:
return res;
}
void common_context_seq_rm(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
auto * mem = llama_get_memory(ctx);
if (!llama_memory_seq_rm(mem, seq_id, p0, p1)) {
GGML_ABORT("%s", string_format("failed to remove sequence %d with p0=%d, p1=%d\n", seq_id, p0, p1).c_str());
}
}
void common_context_seq_cp(llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
auto * mem = llama_get_memory(ctx);
llama_memory_seq_cp(mem, seq_id_src, seq_id_dst, p0, p1);
}
void common_context_seq_add(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
auto * mem = llama_get_memory(ctx);
llama_memory_seq_add(mem, seq_id, p0, p1, delta);
}
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
std::vector<llama_adapter_lora *> loras;
std::vector<float> scales;
@@ -1505,6 +1552,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.n_ctx = params.n_ctx;
cparams.n_seq_max = params.n_parallel;
cparams.n_rs_seq = params.speculative.need_n_rs_seq();
cparams.n_batch = params.n_batch;
cparams.n_ubatch = params.n_ubatch;
cparams.n_threads = params.cpuparams.n_threads;
@@ -2074,3 +2122,11 @@ void common_prompt_checkpoint::load_dft(
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", data_dft.size(), n);
}
}
void common_prompt_checkpoint::clear_tgt() {
data_tgt.clear();
}
void common_prompt_checkpoint::clear_dft() {
data_dft.clear();
}
+19 -1
View File
@@ -13,6 +13,7 @@
#include <string_view>
#include <vector>
#include <map>
#include <algorithm>
#if defined(_WIN32) && !defined(_WIN32_WINNT)
#define _WIN32_WINNT 0x0A00
@@ -159,6 +160,7 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // Multi-token prediction
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
@@ -301,7 +303,7 @@ struct common_params_speculative_draft {
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
float p_min = 0.75f; // minimum speculative decoding probability (greedy) // TODO: change default to 0.0f
common_params_model mparams;
@@ -355,6 +357,14 @@ struct common_params_speculative {
bool has_dft() const {
return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty();
}
uint32_t need_n_rs_seq() const {
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP;
});
return needs_rs_seq ? draft.n_max : 0u;
}
};
struct common_params_vocoder {
@@ -887,12 +897,17 @@ enum common_context_seq_rm_type {
COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module)
COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
COMMON_CONTEXT_SEQ_RM_TYPE_RS = 3, // can seq_rm partial sequences, bounded by n_rs_seq
};
// check if the llama_context can remove sequences
// note: clears the memory of the context
common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx);
// aborts execution on failure
void common_context_seq_rm (llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
void common_context_seq_add(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
void common_context_seq_cp (llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1);
//
// Batch utils
@@ -1074,4 +1089,7 @@ struct common_prompt_checkpoint {
llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags) const;
void clear_tgt();
void clear_dft();
};
+42 -13
View File
@@ -566,8 +566,11 @@ static hf_cache::hf_files get_split_files(const hf_cache::hf_files & files,
return result;
}
static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
const std::string & model) {
// pick the best sibling GGUF whose filename contains `keyword` (e.g. "mmproj" / "mtp"),
// preferring deeper shared directory prefix with the model, then closest quantization
static hf_cache::hf_file find_best_sibling(const hf_cache::hf_files & files,
const std::string & model,
const std::string & keyword) {
hf_cache::hf_file best;
size_t best_depth = 0;
int best_diff = 0;
@@ -579,20 +582,20 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
for (const auto & f : files) {
if (!string_ends_with(f.path, ".gguf") ||
f.path.find("mmproj") == std::string::npos) {
f.path.find(keyword) == std::string::npos) {
continue;
}
auto mmproj_parts = string_split<std::string>(f.path, '/');
auto mmproj_dir = mmproj_parts.end() - 1;
auto sib_parts = string_split<std::string>(f.path, '/');
auto sib_dir = sib_parts.end() - 1;
auto [_, dir] = std::mismatch(model_parts.begin(), model_dir,
mmproj_parts.begin(), mmproj_dir);
if (dir != mmproj_dir) {
sib_parts.begin(), sib_dir);
if (dir != sib_dir) {
continue;
}
size_t depth = dir - mmproj_parts.begin();
size_t depth = dir - sib_parts.begin();
auto bits = extract_quant_bits(f.path);
auto diff = std::abs(bits - model_bits);
@@ -606,6 +609,16 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
return best;
}
static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
const std::string & model) {
return find_best_sibling(files, model, "mmproj");
}
static hf_cache::hf_file find_best_mtp(const hf_cache::hf_files & files,
const std::string & model) {
return find_best_sibling(files, model, "mtp-");
}
static bool gguf_filename_is_model(const std::string & filepath) {
if (!string_ends_with(filepath, ".gguf")) {
return false;
@@ -617,7 +630,8 @@ static bool gguf_filename_is_model(const std::string & filepath) {
}
return filename.find("mmproj") == std::string::npos &&
filename.find("imatrix") == std::string::npos;
filename.find("imatrix") == std::string::npos &&
filename.find("mtp-") == std::string::npos;
}
static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
@@ -673,11 +687,13 @@ struct hf_plan {
hf_cache::hf_file primary;
hf_cache::hf_files model_files;
hf_cache::hf_file mmproj;
hf_cache::hf_file mtp;
};
static hf_plan get_hf_plan(const common_params_model & model,
const common_download_opts & opts,
bool download_mmproj) {
bool download_mmproj,
bool download_mtp) {
hf_plan plan;
hf_cache::hf_files all;
@@ -723,6 +739,10 @@ static hf_plan get_hf_plan(const common_params_model & model,
plan.mmproj = find_best_mmproj(all, primary.path);
}
if (download_mtp) {
plan.mtp = find_best_mtp(all, primary.path);
}
return plan;
}
@@ -756,7 +776,8 @@ static std::vector<download_task> get_url_tasks(const common_params_model & mode
common_download_model_result common_download_model(const common_params_model & model,
const common_download_opts & opts,
bool download_mmproj) {
bool download_mmproj,
bool download_mtp) {
common_download_model_result result;
std::vector<download_task> tasks;
hf_plan hf;
@@ -764,13 +785,16 @@ common_download_model_result common_download_model(const common_params_model &
bool is_hf = !model.hf_repo.empty();
if (is_hf) {
hf = get_hf_plan(model, opts, download_mmproj);
hf = get_hf_plan(model, opts, download_mmproj, download_mtp);
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
if (!hf.mmproj.path.empty()) {
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
}
if (!hf.mtp.path.empty()) {
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
}
} else if (!model.url.empty()) {
tasks = get_url_tasks(model);
} else {
@@ -807,6 +831,10 @@ common_download_model_result common_download_model(const common_params_model &
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
}
if (!hf.mtp.path.empty()) {
result.mtp_path = hf_cache::finalize_file(hf.mtp);
}
} else {
result.model_path = model.path;
}
@@ -946,7 +974,8 @@ std::vector<common_cached_model_info> common_list_cached_models() {
for (const auto & f : files) {
auto split = get_gguf_split_info(f.path);
if (split.index != 1 || split.tag.empty() ||
split.prefix.find("mmproj") != std::string::npos) {
split.prefix.find("mmproj") != std::string::npos ||
split.prefix.find("mtp-") != std::string::npos) {
continue;
}
if (seen.insert(f.repo_id + ":" + split.tag).second) {
+5 -2
View File
@@ -59,6 +59,7 @@ struct common_download_opts {
struct common_download_model_result {
std::string model_path;
std::string mmproj_path;
std::string mtp_path;
};
// Download model from HuggingFace repo or URL
@@ -83,12 +84,14 @@ struct common_download_model_result {
// when opts.offline=true, no network requests are made
// when download_mmproj=true, searches for mmproj in same directory as model or any parent directory
// then with the closest quantization bits
// when download_mtp=true, applies the same sibling search for an MTP-head GGUF
//
// returns result with model_path and mmproj_path (empty on failure)
// returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure)
common_download_model_result common_download_model(
const common_params_model & model,
const common_download_opts & opts = {},
bool download_mmproj = false
bool download_mmproj = false,
bool download_mtp = false
);
// returns list of cached models
+377 -8
View File
@@ -3,6 +3,7 @@
#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP)
#include "log.h"
#include "ngram-cache.h"
#include "ngram-map.h"
@@ -23,6 +24,7 @@ const std::map<std::string, common_speculative_type> common_speculative_type_fro
{"none", COMMON_SPECULATIVE_TYPE_NONE},
{"draft-simple", COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE},
{"draft-eagle3", COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3},
{"draft-mtp", COMMON_SPECULATIVE_TYPE_DRAFT_MTP},
{"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
{"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -143,6 +145,9 @@ struct common_speculative_impl {
virtual void draft(common_speculative_draft_params_vec & dparams) = 0;
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0;
// true if this implementation requires the target context to extract embeddings
virtual bool need_embd() const = 0;
};
struct common_speculative_impl_draft_simple : public common_speculative_impl {
@@ -338,6 +343,10 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
// noop
}
bool need_embd() const override {
return false;
}
};
struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
@@ -362,6 +371,328 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
// noop
}
bool need_embd() const override {
return false;
}
};
struct common_speculative_state_draft_mtp : public common_speculative_impl {
common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)
llama_batch batch;
std::vector<common_sampler_ptr> smpls;
int32_t n_embd = 0;
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
// The last h-row of one process() call needs the first token of the NEXT
// call to pair with, so it's stashed here until that next call fires.
std::vector<std::vector<float>> pending_h; // [n_seq][n_embd]
std::vector<int32_t> i_batch_beg;
std::vector<int32_t> i_batch_end;
// Hidden rows from the most recent target verification batch, grouped by seq.
// Row 0 corresponds to the sampled token, row N to the Nth accepted draft token.
std::vector<std::vector<float>> verify_h;
std::vector<int32_t> verify_h_rows;
// Per-seq draft length from the last draft() call, used in accept() to
// roll back ctx_dft's recurrent state past the AR draft's redundant
// pre-advancement before process() mirrored the verify batch.
std::vector<uint16_t> last_n_drafted;
common_speculative_state_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
, params(params.draft)
{
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");
n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
const int32_t n_b = (int32_t) llama_n_batch(ctx_dft);
batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
// llama_batch_init allocates only one of token/embd; MTP needs both.
// TODO: fix, how to call without malloc
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_b);
smpls.resize(n_seq);
for (auto & s : smpls) {
common_params_sampling sparams;
sparams.no_perf = false;
sparams.top_k = 1; // TODO: re-enable top_k == 10 and utilize `p_min` spec param
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
}
llama_set_embeddings_pre_norm(ctx_tgt, true);
llama_set_embeddings_pre_norm(ctx_dft, true);
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
i_batch_beg.assign(n_seq, -1);
i_batch_end.assign(n_seq, -1);
verify_h.assign(n_seq, {});
verify_h_rows.assign(n_seq, 0);
last_n_drafted.assign(n_seq, 0);
}
~common_speculative_state_draft_mtp() override {
if (batch.token != nullptr) {
free(batch.token);
batch.token = nullptr;
}
llama_batch_free(batch);
}
void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
const int32_t N = (int32_t) prompt.size();
if (N <= 0) {
return;
}
auto * ctx_dft = this->params.ctx_dft;
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 1) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d — "
"process() hook may not have run on every prefill ubatch "
"(need_embd / logits=1 on every prompt position?). "
"Drafts may degrade.\n",
__func__, (int) pos_max, N - 1);
}
}
bool process(const llama_batch & batch_in) override {
if (batch_in.n_tokens <= 0) {
return true;
}
// TODO: how to make it work with vision tokens?
if (batch_in.token == nullptr || batch_in.embd != nullptr) {
return true;
}
const int32_t n_tokens = batch_in.n_tokens;
// remember the frist and last batch index for each sequence
std::fill(i_batch_beg.begin(), i_batch_beg.end(), -1);
std::fill(i_batch_end.begin(), i_batch_end.end(), -1);
for (int k = 0; k < n_tokens; ++k) {
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
GGML_ASSERT(batch_in.n_seq_id[k] == 1);
if (batch_in.seq_id[k][0] == seq_id) {
i_batch_end[seq_id] = k;
if (i_batch_beg[seq_id] < 0) {
i_batch_beg[seq_id] = k;
}
}
}
}
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
const size_t row_bytes = (size_t) n_embd * sizeof(float);
common_batch_clear(batch);
for (int k = 0; k < n_tokens; ++k) {
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
}
// shift the tgt embeddings to the right by one position
// assumes that the tokens in the batch are sequential for each sequence
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
// ^--- this is a problem
// TODO:this is generally true, but would be nice to assert it
{
const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
//{
// // string with seq_ids in the batch
// std::stringstream ss;
// for (int i = 0; i < n_tokens; ++i) {
// ss << batch_in.seq_id[i][0] << ",";
// }
// LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str());
//}
}
// fill the pending embeddings from a previous run
auto set_h = [&](int idx, const float * h_row) {
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
};
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
return false;
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_end[seq_id] < 0) {
continue;
}
const int32_t n_rows = i_batch_end[seq_id] - i_batch_beg[seq_id] + 1;
verify_h_rows[seq_id] = n_rows;
verify_h[seq_id].resize((size_t) n_rows * n_embd);
for (int32_t i = 0; i < n_rows; ++i) {
const float * h = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_beg[seq_id] + i);
std::memcpy(verify_h[seq_id].data() + (size_t) i * n_embd, h, row_bytes);
}
std::memcpy(pending_h[seq_id].data(),
verify_h[seq_id].data() + (size_t) (n_rows - 1) * n_embd, row_bytes);
}
return true;
}
void draft(common_speculative_draft_params_vec & dparams) override {
auto & ctx_dft = params.ctx_dft;
common_batch_clear(batch);
// keep track of which sequences are still drafting
int n_drafting = 0;
std::vector<bool> drafting(n_seq);
const float * h_row = nullptr;
const size_t row_bytes = (size_t) n_embd * sizeof(float);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
continue;
}
n_drafting++;
drafting[seq_id] = true;
common_sampler_reset(smpls[seq_id].get());
common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);
h_row = pending_h[seq_id].data();
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
}
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
return;
}
int i = 0;
while (n_drafting > 0) {
int i_batch = 0;
common_batch_clear(batch);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (!drafting[seq_id]) {
continue;
}
auto * smpl = smpls[seq_id].get();
common_sampler_sample(smpl, ctx_dft, i_batch, true);
h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, i_batch);
++i_batch;
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;
common_sampler_accept(smpl, id, true);
auto & dp = dparams.at(seq_id);
auto & result = *dp.result;
result.push_back(id);
if (params.n_max <= (int) result.size()) {
drafting[seq_id] = false;
n_drafting--;
continue;
}
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
}
if (batch.n_tokens == 0) {
break;
}
// evaluate the drafted tokens on the draft model
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
break;
}
++i;
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
continue;
}
if (dp.result->size() < (size_t) params.n_min) {
dp.result->clear();
}
last_n_drafted[seq_id] = (uint16_t) dp.result->size();
}
}
void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
return;
}
const int32_t n_rows = verify_h_rows[seq_id];
if (n_rows <= 0) {
return;
}
const int32_t i_h = std::min<int32_t>(n_accepted, n_rows - 1);
const size_t row_bytes = (size_t) n_embd * sizeof(float);
std::memcpy(pending_h[seq_id].data(), verify_h[seq_id].data() + (size_t) i_h * n_embd, row_bytes);
}
bool need_embd() const override {
return true;
}
};
// state of self-speculation (simple implementation, not ngram-map)
@@ -403,6 +734,10 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
// noop
}
bool need_embd() const override {
return false;
}
};
struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
@@ -451,6 +786,10 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
common_ngram_map_accept(config[seq_id], n_accepted);
}
bool need_embd() const override {
return false;
}
};
struct common_speculative_impl_ngram_mod : public common_speculative_impl {
@@ -619,6 +958,10 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
}
}
}
bool need_embd() const override {
return false;
}
};
struct common_speculative_impl_ngram_cache : public common_speculative_impl {
@@ -752,6 +1095,10 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
// noop
}
bool need_embd() const override {
return false;
}
};
struct common_speculative {
@@ -820,6 +1167,7 @@ std::string common_speculative_type_to_str(common_speculative_type type) {
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple";
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3";
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: return "draft-mtp";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v";
@@ -875,8 +1223,8 @@ common_speculative * common_speculative_init(common_params_speculative & params,
bool has_draft_model_path = !params.draft.mparams.path.empty();
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
// bool has_mtp = false; // TODO: add MTP here
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
bool has_ngram_cache = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_CACHE));
bool has_ngram_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE));
@@ -885,7 +1233,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD));
// when adding a new type - update here the logic above
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 8);
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 9);
// this list here defines the priority of the speculators
// the one with highest priority are listed first
@@ -911,7 +1259,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
LOG_WRN("%s: draft model is not specified - cannot use 'draft' type\n", __func__);
has_draft_simple = false;
}
} else if (has_draft_model_path) {
} else if (has_draft_model_path && !has_mtp && !has_draft_eagle3) {
LOG_WRN("%s: draft model is specified but 'draft' speculative type is not explicitly enabled - enabling it\n", __func__);
has_draft_simple = true;
}
@@ -919,10 +1267,12 @@ common_speculative * common_speculative_init(common_params_speculative & params,
if (has_draft_simple) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, params));
}
// TODO: add MTP here
if (has_draft_eagle3) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
}
if (has_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
}
}
std::vector<std::unique_ptr<common_speculative_impl>> impls = {};
@@ -940,6 +1290,10 @@ common_speculative * common_speculative_init(common_params_speculative & params,
impls.push_back(std::make_unique<common_speculative_impl_draft_eagle3>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: {
impls.push_back(std::make_unique<common_speculative_state_draft_mtp>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
@@ -1040,6 +1394,20 @@ bool common_speculative_process(common_speculative * spec, const llama_batch & b
return result;
}
bool common_speculative_need_embd(common_speculative * spec) {
if (spec == nullptr) {
return false;
}
for (auto & impl : spec->impls) {
if (impl->need_embd()) {
return true;
}
}
return false;
}
void common_speculative_draft(common_speculative * spec) {
if (spec == nullptr) {
return;
@@ -1122,14 +1490,15 @@ void common_speculative_draft(common_speculative * spec) {
}
void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted) {
if (n_accepted == 0) {
return;
}
common_speculative_impl * impl = spec->impl_last[seq_id];
GGML_ASSERT(impl);
// TODO: currently only the implementation that generated the draft is used to accept it
// however, some implementations (such as MTP) need to also "see" the accepted tokens
// extend `common_speculative_impl::accept()` with an extra argument `bool is_other` to
// inform the implementation if the accepted tokens are from another implementation and
// pass the accepted tokens to all remaining implementations using `is_other == true`
{
common_time_meas tm(impl->t_accept_us, !impl->gen_perf);
if (n_accepted > 0) {
+3
View File
@@ -53,6 +53,9 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co
// process the batch and update the internal state of the speculative context
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
// true if any implementation requires target embeddings to be extracted
bool common_speculative_need_embd(common_speculative * spec);
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
void common_speculative_draft(common_speculative * spec);
+6
View File
@@ -91,6 +91,7 @@ class ModelBase:
gguf_writer: gguf.GGUFWriter
model_name: str | None
metadata_override: Path | None
metadata: gguf.Metadata
dir_model_card: Path
remote_hf_model_id: str | None
@@ -106,6 +107,11 @@ class ModelBase:
disable_mistral_community_chat_template: bool = False
sentence_transformers_dense_modules: bool = False
# MTP (multi-token prediction) export modes; set by main() before instantiation.
# Architectures opt in by overriding the handling (see _Qwen35MtpMixin).
mtp_only: bool = False
no_mtp: bool = False
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
+86 -3
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import Callable, Iterable, TYPE_CHECKING
from pathlib import Path
from typing import Any, Callable, Iterable, TYPE_CHECKING
import torch
@@ -534,11 +535,93 @@ class _Qwen35MRopeMixin:
self.gguf_writer.add_rope_dimension_sections(self._QWEN35_DEFAULT_MROPE_SECTION)
class _Qwen35MtpMixin:
"""Shared MTP wiring for Qwen3.5/3.6 text variants. The HF config carries
the MTP block under `mtp_num_hidden_layers` and the tensors under
`mtp.*`; we extend block_count, emit the nextn metadata key, and remap
`mtp.*` to the standard layer-indexed nextn naming so the existing
tensor_map handles them."""
hparams: dict[str, Any]
model_arch: gguf.MODEL_ARCH
gguf_writer: gguf.GGUFWriter
block_count: int
tensor_map: gguf.TensorNameMap
no_mtp: bool
mtp_only: bool
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.block_count = self.hparams["num_hidden_layers"]
if not self.no_mtp:
self.block_count += self.hparams.get("mtp_num_hidden_layers", 0)
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
@classmethod
def filter_tensors(cls, item):
name, _ = item
if name.startswith("mtp."):
if cls.no_mtp:
return None
return item
if cls.mtp_only:
canonical = name.replace("language_model.", "")
keep = canonical in (
"model.embed_tokens.weight", "model.norm.weight", "lm_head.weight",
"embed_tokens.weight", "norm.weight",
)
if not keep:
return None
return super().filter_tensors(item) # ty: ignore[unresolved-attribute]
def set_gguf_parameters(self):
super().set_gguf_parameters() # ty: ignore[unresolved-attribute]
if self.no_mtp:
return
if (n := self.hparams.get("mtp_num_hidden_layers", 0)) > 0:
self.gguf_writer.add_nextn_predict_layers(n)
def prepare_metadata(self, vocab_only: bool):
from_dir = self.fname_out.is_dir()
super().prepare_metadata(vocab_only=vocab_only) # ty: ignore[unresolved-attribute]
if not self.mtp_only or not from_dir:
return
output_type: str = self.ftype.name.partition("_")[2] # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
fname_default: str = gguf.naming_convention(
self.metadata.name, self.metadata.basename, self.metadata.finetune, # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
self.metadata.version, size_label=None, output_type=output_type, model_type=None) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
self.fname_out = self.fname_out.parent / f"mtp-{fname_default}.gguf"
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("mtp."):
n_layer = self.hparams["num_hidden_layers"]
if name.find("layers.") != -1:
assert bid is not None
name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + n_layer}")
else:
remapper = {
"mtp.fc": "model.layers.{bid}.eh_proj",
"mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
"mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm",
"mtp.norm": "model.layers.{bid}.shared_head.norm",
}
stem = Path(name).stem
suffix = Path(name).suffix
tmpl = remapper[stem] + suffix
for b in range(n_layer, self.block_count):
yield from super().modify_tensors(data_torch, tmpl.format(bid=b), b) # ty: ignore[unresolved-attribute]
return
yield from super().modify_tensors(data_torch, name, bid) # ty: ignore[unresolved-attribute]
@ModelBase.register("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM")
class Qwen3_5TextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase):
class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
model_arch = gguf.MODEL_ARCH.QWEN35
@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
class Qwen3_5MoeTextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase):
class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
model_arch = gguf.MODEL_ARCH.QWEN35MOE
+22
View File
@@ -117,6 +117,14 @@ def parse_args() -> argparse.Namespace:
"--mmproj", action="store_true",
help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.",
)
parser.add_argument(
"--mtp", action="store_true",
help="(Experimental) Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. Output file name will get a '-MTP' suffix.",
)
parser.add_argument(
"--no-mtp", action="store_true",
help="(Experimental) Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. Note: the split form duplicates embeddings, so the bundled default is more space-efficient overall.",
)
parser.add_argument(
"--mistral-format", action="store_true",
help="Whether the model is stored following the Mistral format.",
@@ -233,6 +241,20 @@ def main() -> None:
from conversion.mistral import MistralModel
model_class = MistralModel
if args.mtp and args.no_mtp:
logger.error("--mtp and --no-mtp are mutually exclusive")
sys.exit(1)
if args.mtp or args.no_mtp:
from conversion.qwen import _Qwen35MtpMixin
if not issubclass(model_class, _Qwen35MtpMixin):
logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 text variants today")
sys.exit(1)
if args.no_mtp:
model_class.no_mtp = True
if args.mtp:
model_class.mtp_only = True
model_instance = model_class(dir_model, output_type, fname_out,
is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
eager=args.no_lazy,
+5
View File
@@ -2541,6 +2541,11 @@ extern "C" {
// TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST]
// ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306
//
// state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs):
// K == 1: output carries the final state only.
// K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K)
// per-token snapshots into the trailing slots
GGML_API struct ggml_tensor * ggml_gated_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q,
+3 -2
View File
@@ -753,7 +753,9 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co
GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1);
GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1);
GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1);
GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2);
// state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0,
// so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2).
GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0);
return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1};
};
@@ -2140,4 +2142,3 @@ ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, siz
const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context;
return backend_ctx->backend_configs[index].backend;
}
+3 -1
View File
@@ -2943,7 +2943,9 @@ struct ggml_cplan ggml_graph_plan(
case GGML_OP_GATED_DELTA_NET:
{
const int64_t S_v = node->src[2]->ne[0];
cur = S_v * sizeof(float) * n_tasks;
const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs)
const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0);
cur = per_thread * sizeof(float) * n_tasks;
} break;
case GGML_OP_COUNT:
{
+32 -7
View File
@@ -10513,19 +10513,30 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const bool kda = (neg0 == S_v);
// scratch layout per thread: [delta(S_v)]
const int64_t scratch_per_thread = S_v;
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
const int64_t K = src_state->ne[1];
GGML_ASSERT(K >= 1);
// per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride)
const int64_t state_seq_stride = src_state->nb[2] / sizeof(float);
const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0);
const int ith = params->ith;
float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32;
float * state_work = K > 1 ? (delta + S_v) : nullptr;
// output layout: [attn_scores | new_states]
// attn_scores: S_v * H * n_tokens * n_seqs floats
// new_states: S_v * S_v * H * n_seqs floats
// new_states: S_v * S_v * H * n_seqs * K floats (K snapshot slots; last min(n_tokens, K))
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
const int64_t state_size_per_snap = S_v * S_v * H * n_seqs;
float * attn_out_base = (float *)dst->data;
float * state_out_base = (float *)dst->data + attn_score_elems;
// snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last
// n_tokens slots are written; earlier slots are left untouched (caller-owned).
const int64_t shift = n_tokens - K;
const float * state_in_base = (const float *)src_state->data;
//const int64_t rq1 = nev1 / neq1;
@@ -10545,10 +10556,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const int64_t iq3 = iv3 / rq3;
const int64_t ik3 = iv3 / rk3;
float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;
// For K=1, write directly to the single output slot to avoid an extra memcpy at the end.
// For K>1, work in scratch and copy out per-token when the slot is in range.
float * s_out = (K > 1)
? state_work
: state_out_base + (iv3 * H + iv1) * S_v * S_v;
// copy input state into output buffer and operate in-place
const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
// copy input state into the working buffer and operate in-place
// state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride.
const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v;
memcpy(s_out, s_in, S_v * S_v * sizeof(float));
// attn output pointer for first token of this (head, seq)
@@ -10598,6 +10614,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
}
attn_data += S_v * H; // advance to next token
if (K > 1) {
const int64_t target_slot = t - shift;
if (target_slot >= 0 && target_slot < K) {
float * curr_state_o = state_out_base + target_slot * state_size_per_snap +
(iv3 * H + iv1) * S_v * S_v;
memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float));
}
}
}
}
}
+58 -20
View File
@@ -1,6 +1,6 @@
#include "gated_delta_net.cuh"
template <int S_v, bool KDA>
template <int S_v, bool KDA, bool keep_rs_t>
__global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
gated_delta_net_cuda(const float * q,
const float * k,
@@ -23,7 +23,8 @@ gated_delta_net_cuda(const float * q,
int64_t sb3,
const uint3 neqk1_magic,
const uint3 rq3_magic,
float scale) {
float scale,
int K) {
const uint32_t h_idx = blockIdx.x;
const uint32_t sequence = blockIdx.y;
// each warp owns one column, using warp-level primitives to reduce across rows
@@ -37,9 +38,13 @@ gated_delta_net_cuda(const float * q,
float * attn_data = dst;
float * state = dst + attn_score_elems;
const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
state += state_offset;
curr_state += state_offset + col * S_v;
// input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v.
// output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v;
const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v;
const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
state += state_out_offset;
curr_state += state_in_offset + col * S_v;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
@@ -54,6 +59,10 @@ gated_delta_net_cuda(const float * q,
s_shard[r] = curr_state[i];
}
// slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots
// are written; earlier slots are left untouched (caller-owned).
const int shift = (int) n_tokens - K;
for (int t = 0; t < n_tokens; t++) {
const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
@@ -135,17 +144,30 @@ gated_delta_net_cuda(const float * q,
}
attn_data += S_v * H;
if constexpr (keep_rs_t) {
const int target_slot = t - shift;
if (target_slot >= 0 && target_slot < K) {
float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
curr_state[col * S_v + i] = s_shard[r];
}
}
}
}
// Write state back to global memory (transposed layout)
if constexpr (!keep_rs_t) {
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
state[col * S_v + i] = s_shard[r];
}
}
}
template <bool KDA>
template <bool KDA, bool keep_rs_t>
static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
const float * g_d, const float * b_d, const float * s_d,
@@ -155,7 +177,7 @@ static void launch_gated_delta_net(
int64_t sv1, int64_t sv2, int64_t sv3,
int64_t sb1, int64_t sb2, int64_t sb3,
int64_t neqk1, int64_t rq3,
float scale, cudaStream_t stream) {
float scale, int K, cudaStream_t stream) {
//TODO: Add chunked kernel for even faster pre-fill
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const int num_warps = 4;
@@ -169,29 +191,29 @@ static void launch_gated_delta_net(
switch (S_v) {
case 16:
gated_delta_net_cuda<16, KDA><<<grid_dims, block_dims, 0, stream>>>(
gated_delta_net_cuda<16, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
break;
case 32:
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
gated_delta_net_cuda<32, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
break;
case 64: {
gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
gated_delta_net_cuda<64, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
break;
}
case 128: {
gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
gated_delta_net_cuda<128, KDA, keep_rs_t><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
break;
}
default:
@@ -261,13 +283,29 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
cudaStream_t stream = ctx.stream();
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
const int K = (int) src_state->ne[1];
const bool keep_rs = K > 1;
if (kda) {
launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
if (keep_rs) {
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, stream);
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
} else {
launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, stream);
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
}
} else {
if (keep_rs) {
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
} else {
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
}
}
}
+4 -1
View File
@@ -590,6 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(
const int ne20 = op->src[2]->ne[0]; // S_v
const int ne21 = op->src[2]->ne[1]; // H
const int ne30 = op->src[3]->ne[0]; // G
// state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
const int K = op->src[5]->ne[1];
const int nsg = op->src[2]->ne[0]/32;
@@ -598,7 +600,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(
GGML_ASSERT(ne20 % 32 == 0);
snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg);
snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30);
snprintf(name, 256, "%s_ne20=%d_ne30=%d_K=%d", base, ne20, ne30, K);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
@@ -606,6 +608,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(
ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);
ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);
ggml_metal_cv_set_int16(cv, K, FC_GATED_DELTA_NET + 2);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+35 -7
View File
@@ -2531,6 +2531,7 @@ kernel void kernel_rwkv_wkv7_f32(
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
#if 1
template<short NSG>
@@ -2548,21 +2549,24 @@ kernel void kernel_gated_delta_net_impl(
uint3 ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G FC_gated_delta_net_ne30
#define K FC_gated_delta_net_K
const uint tx = tpitg.x;
const uint ty = tpitg.y;
const uint i23 = tgpig.z; // B
const uint i21 = tgpig.y; // H
const uint i20 = tgpig.x*NSG + ty;
const uint i23 = tgpig.z; // B (n_seqs)
const uint i21 = tgpig.y; // H (head)
const uint i20 = tgpig.x*NSG + ty; // row within S_v
const uint i01 = i21 % args.ne01;
const uint i11 = i21 % args.ne11;
const float scale = 1.0f / sqrt((float)S_v);
// input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0.
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v;
device const float * s_ptr = (device const float *) (s) + state_in_base;
float ls[NSG];
@@ -2580,6 +2584,17 @@ kernel void kernel_gated_delta_net_impl(
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
// snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last
// n_tokens slots are written; earlier slots are left untouched (caller-owned).
const int shift = (int)args.ne22 - (int)K;
// output state base offset: after attention scores
const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
// output state per-slot size: S_v * S_v * H * n_seqs
const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23;
// per-(seq,head) offset within a slot
const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
for (short t = 0; t < args.ne22; t++) {
float s_k = 0.0f;
@@ -2627,17 +2642,30 @@ kernel void kernel_gated_delta_net_impl(
b_ptr += args.ne21;
g_ptr += args.ne21*G;
}
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
if (K > 1u) {
const int target_slot = (int)t - shift;
if (target_slot >= 0 && target_slot < (int)K) {
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is] = ls[j];
}
}
}
}
if (K == 1u) {
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is] = ls[j];
}
}
#undef S_v
#undef G
#undef K
}
typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
+7 -1
View File
@@ -1506,6 +1506,7 @@ struct vk_op_gated_delta_net_push_constants {
uint32_t sb1, sb2, sb3;
uint32_t neq1, rq3;
float scale;
uint32_t K;
};
struct vk_op_ssm_scan_push_constants {
@@ -10767,6 +10768,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
const ggml_tensor * src_q = dst->src[0];
const ggml_tensor * src_v = dst->src[2];
const ggml_tensor * src_beta = dst->src[4];
const ggml_tensor * src_state = dst->src[5];
GGML_ASSERT(dst->buffer != nullptr);
@@ -10775,6 +10777,9 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
const uint32_t n_tokens = (uint32_t)src_v->ne[2];
const uint32_t n_seqs = (uint32_t)src_v->ne[3];
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
const uint32_t K = (uint32_t)src_state->ne[1];
const uint32_t s_off = S_v * H * n_tokens * n_seqs;
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
@@ -10808,7 +10813,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
sv1, sv2, sv3,
sb1, sb2, sb3,
neq1, rq3,
scale
scale,
K
};
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
@@ -31,6 +31,7 @@ layout(push_constant) uniform Parameters {
uint sb1, sb2, sb3;
uint neq1, rq3;
float scale;
uint K;
};
layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; };
@@ -101,13 +102,21 @@ void main() {
const uint iq3 = seq_id / rq3;
const uint state_size = S_V * S_V;
const uint state_base = (seq_id * H + head_id) * state_size;
// input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0.
const uint state_in_base = (seq_id * K * H + head_id) * state_size;
// output state layout per slot: same per-(seq,head) offset as the single-slot case.
const uint state_out_base = (seq_id * H + head_id) * state_size;
const uint state_size_per_snap = state_size * H * n_seqs;
FLOAT_TYPE s_shard[ROWS_PER_LANE];
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]);
s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]);
}
// snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last
// n_tokens slots are written; earlier slots are left untouched (caller-owned).
const int shift = int(n_tokens) - int(K);
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
for (uint t = 0; t < n_tokens; t++) {
@@ -161,9 +170,21 @@ void main() {
}
attn_off += S_V * H;
if (K > 1u) {
const int target_slot = int(t) - shift;
if (target_slot >= 0 && target_slot < int(K)) {
const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base;
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
data_dst[slot_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
}
}
}
}
if (K == 1u) {
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
data_dst[s_off + state_out_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r];
}
}
}
+7 -5
View File
@@ -6210,11 +6210,13 @@ struct ggml_tensor * ggml_gated_delta_net(
GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
GGML_ASSERT(beta->ne[0] == 1);
GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);
// concat output and new_state into a single tensor
// output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs
const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 };
// state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count.
GGML_ASSERT(state->ne[0] == S_v * S_v * H);
GGML_ASSERT(state->ne[2] == n_seqs);
GGML_ASSERT(state->ne[3] == 1);
const int64_t K = state->ne[1];
const int64_t state_rows = K * S_v * n_seqs;
const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
result->op = GGML_OP_GATED_DELTA_NET;
+16 -2
View File
@@ -2114,7 +2114,14 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_BETA,
MODEL_TENSOR.SSM_ALPHA,
MODEL_TENSOR.SSM_OUT
MODEL_TENSOR.SSM_OUT,
# NextN/MTP tensors - preserved but unused
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.QWEN35MOE: [
MODEL_TENSOR.TOKEN_EMBD,
@@ -2145,7 +2152,14 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_BETA,
MODEL_TENSOR.SSM_ALPHA,
MODEL_TENSOR.SSM_OUT
MODEL_TENSOR.SSM_OUT,
# NextN/MTP tensors - preserved but unused
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.PLAMO: [
MODEL_TENSOR.TOKEN_EMBD,
+8
View File
@@ -198,6 +198,11 @@ extern "C" {
LLAMA_SPLIT_MODE_TENSOR = 3,
};
enum llama_context_type {
LLAMA_CONTEXT_TYPE_DEFAULT = 0,
LLAMA_CONTEXT_TYPE_MTP = 1,
};
// TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979)
typedef struct llama_token_data {
llama_token id; // token id
@@ -333,9 +338,11 @@ extern "C" {
uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
uint32_t n_ubatch; // physical maximum batch size
uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL]
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing
enum llama_context_type ctx_type; // set the context type (e.g. MTP)
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
enum llama_attention_type attention_type; // attention type to use for embeddings
@@ -530,6 +537,7 @@ extern "C" {
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead");
+19 -8
View File
@@ -757,14 +757,15 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
// These tensors only exist in the last layer(s) and are treated as output tensors
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
// last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so
// the model loader doesn't fault on the block index.
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
// Nemotron 3 Super
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -877,6 +878,16 @@ bool llm_arch_is_diffusion(const llm_arch & arch) {
}
}
bool llm_arch_supports_rs_rollback(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_QWEN35:
case LLM_ARCH_QWEN35MOE:
return true;
default:
return false;
}
}
bool llm_arch_supports_sm_tensor(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_GROK:
+1
View File
@@ -637,3 +637,4 @@ bool llm_arch_is_recurrent (const llm_arch & arch);
bool llm_arch_is_hybrid (const llm_arch & arch);
bool llm_arch_is_diffusion (const llm_arch & arch);
bool llm_arch_supports_sm_tensor(const llm_arch & arch);
bool llm_arch_supports_rs_rollback(const llm_arch & arch);
+133 -6
View File
@@ -2,6 +2,7 @@
#include "ggml.h"
#include "llama-arch.h"
#include "llama-graph.h"
#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-io.h"
@@ -21,6 +22,14 @@
// llama_context
//
static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) {
switch (ctx_type) {
case LLAMA_CONTEXT_TYPE_DEFAULT: return LLM_GRAPH_TYPE_DEFAULT;
case LLAMA_CONTEXT_TYPE_MTP : return LLM_GRAPH_TYPE_DECODER_MTP;
}
throw std::runtime_error("Unsupported ctx type");
}
llama_context::llama_context(
const llama_model & model,
llama_context_params params) :
@@ -42,6 +51,13 @@ llama_context::llama_context(
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
}
cparams.n_rs_seq = params.n_rs_seq;
if (cparams.n_rs_seq > 0 && !llm_arch_supports_rs_rollback(model.arch)) {
LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n",
__func__, cparams.n_rs_seq);
cparams.n_rs_seq = 0;
}
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
@@ -49,6 +65,7 @@ llama_context::llama_context(
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
cparams.embeddings = params.embeddings;
cparams.embeddings_pre_norm = false;
cparams.offload_kqv = params.offload_kqv;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
@@ -65,6 +82,8 @@ llama_context::llama_context(
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.ctx_type = params.ctx_type;
// Initialize backend samplers here so they are part of the sampling graph
// before the reserve passes run later in this function. This avoids a later
// re-reserve when graph nodes change.
@@ -206,6 +225,7 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
LLAMA_LOG_INFO("%s: n_rs_seq = %u\n", __func__, cparams.n_rs_seq);
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
@@ -278,6 +298,7 @@ llama_context::llama_context(
/*.type_k =*/ params.type_k,
/*.type_v =*/ params.type_v,
/*.swa_full =*/ params.swa_full,
/*.ctx_type= */ cparams.ctx_type,
};
memory.reset(model.create_memory(params_mem, cparams));
@@ -860,6 +881,33 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
return it->second.data();
}
float * llama_context::get_embeddings_pre_norm() {
output_reorder();
return embd_pre_norm.data;
}
float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
output_reorder();
try {
if (embd_pre_norm.data == nullptr) {
throw std::runtime_error("no pre-norm embeddings");
}
const int64_t j = output_resolve_row(i);
const uint32_t n_embd = model.hparams.n_embd;
return embd_pre_norm.data + j*n_embd;
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
GGML_ABORT("fatal error");
#else
return nullptr;
#endif
}
}
llama_token llama_context::get_sampled_token_ith(int32_t idx) {
output_reorder();
@@ -1040,6 +1088,12 @@ void llama_context::set_embeddings(bool value) {
//sched_need_reserve = true;
}
void llama_context::set_embeddings_pre_norm(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
cparams.embeddings_pre_norm = value;
}
void llama_context::set_causal_attn(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
@@ -1241,7 +1295,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
}
int llama_context::encode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
// so accept either present rather than requiring exactly one.
GGML_ASSERT(batch_inp.token || batch_inp.embd);
if (batch_inp.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -1314,6 +1370,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
auto * t_logits = res->get_logits();
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
// extract logits
if (logits.data && t_logits) {
@@ -1379,6 +1436,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
}
}
// extract pre-norm embeddings (hidden state before the final output norm)
if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
GGML_ASSERT(backend_h != nullptr);
const uint32_t n_embd = hparams.n_embd;
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size);
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float));
}
// TODO: hacky solution
if (model.arch == LLM_ARCH_T5 && t_embd) {
//cross.t_embd = t_embd;
@@ -1531,7 +1598,9 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s
}
int llama_context::decode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
// so accept either present rather than requiring exactly one.
GGML_ASSERT(batch_inp.token || batch_inp.embd);
if (!memory) {
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
@@ -1689,7 +1758,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
ggml_status status;
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
const auto * res = process_ubatch(ubatch, ctx_type_to_graph_type(cparams.ctx_type), mctx.get(), status);
if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
@@ -1729,6 +1799,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
if (t_embd && res->get_embd_pooled()) {
t_embd = res->get_embd_pooled();
@@ -1809,6 +1880,20 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
}
// extract pre-norm embeddings (hidden state before the final output norm)
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
GGML_ASSERT(backend_h != nullptr);
const uint32_t n_embd = hparams.n_embd;
float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd;
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size);
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float));
}
// Copy backend sampling output if this ubatch produced any sampling tensors.
if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
@@ -1893,10 +1978,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;
const auto n_embd_out = hparams.n_embd_out();
bool has_logits = true;
bool has_embd = cparams.embeddings;
bool has_embd_pre_norm = cparams.embeddings_pre_norm;
// TODO: hacky enc-dec support
if (model.arch == LLM_ARCH_T5) {
@@ -1910,6 +1997,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;
// Allocate backend sampling output buffers if there are backend samplers configured.
const bool has_sampling = !sampling.samplers.empty();
@@ -1925,7 +2013,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
const size_t new_size =
(logits.size + embd.size + backend_float_count) * sizeof(float) +
(logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) +
( backend_token_count) * sizeof(llama_token);
// alloc only when more than the current capacity is required
@@ -1942,6 +2030,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
buf_output = nullptr;
logits.data = nullptr;
embd.data = nullptr;
embd_pre_norm.data = nullptr;
}
auto * buft = ggml_backend_cpu_buffer_type();
@@ -1970,6 +2059,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
offset += embd.size * sizeof(float);
embd_pre_norm = has_embd_pre_norm ? buffer_view<float>{(float *) (base + offset), embd_pre_norm.size} : buffer_view<float>{nullptr, 0};
offset += embd_pre_norm.size * sizeof(float);
if (has_sampling) {
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
offset += sampling.logits.size * sizeof(float);
@@ -2034,6 +2126,12 @@ void llama_context::output_reorder() {
}
}
if (embd_pre_norm.size > 0) {
for (uint64_t k = 0; k < n_embd; k++) {
std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]);
}
}
if (!sampling.samplers.empty()) {
assert(sampling.logits.size > 0);
assert(sampling.probs.size > 0);
@@ -2121,7 +2219,7 @@ ggml_cgraph * llama_context::graph_reserve(
auto * res = gf_res_reserve.get();
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
const auto gparams = graph_params(res, ubatch, mctx, ctx_type_to_graph_type(cparams.ctx_type));
res->reset();
@@ -3100,7 +3198,7 @@ void llama_context::opt_epoch_iter(
auto * res = gf_res_prev.get();
const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
const auto gparams = graph_params(res, ubatch, mctx.get(), ctx_type_to_graph_type(cparams.ctx_type));
res->reset();
@@ -3201,8 +3299,10 @@ llama_context_params llama_context_default_params() {
/*.n_batch =*/ 2048,
/*.n_ubatch =*/ 512,
/*.n_seq_max =*/ 1,
/*.n_rs_seq =*/ 0,
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
/*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
@@ -3306,6 +3406,13 @@ llama_context * llama_init_from_model(
model->hparams.pooling_type, params.pooling_type);
}
if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP &&
model->hparams.nextn_predict_layers == 0) {
LLAMA_LOG_WARN("%s: context type MTP requested but model doesn't contain MTP layers\n", __func__);
return nullptr;
}
try {
auto * ctx = new llama_context(*model, params);
return ctx;
@@ -3347,6 +3454,10 @@ uint32_t llama_n_seq_max(const llama_context * ctx) {
return ctx->n_seq_max();
}
uint32_t llama_n_rs_seq(const llama_context * ctx) {
return ctx->get_cparams().n_rs_seq;
}
const llama_model * llama_get_model(const llama_context * ctx) {
return &ctx->get_model();
}
@@ -3436,6 +3547,22 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
return ctx->get_embeddings_seq(seq_id);
}
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value) {
ctx->set_embeddings_pre_norm(value);
}
float * llama_get_embeddings_pre_norm(llama_context * ctx) {
ctx->synchronize();
return ctx->get_embeddings_pre_norm();
}
float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) {
ctx->synchronize();
return ctx->get_embeddings_pre_norm_ith(i);
}
bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
return ctx->set_sampler(seq_id, smpl);
}
+9
View File
@@ -84,6 +84,9 @@ struct llama_context {
float * get_embeddings_ith(int32_t i);
float * get_embeddings_seq(llama_seq_id seq_id);
float * get_embeddings_pre_norm();
float * get_embeddings_pre_norm_ith(int32_t i);
llama_token * get_sampled_tokens() const;
llama_token get_sampled_token_ith(int32_t idx);
@@ -107,6 +110,7 @@ struct llama_context {
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
void set_embeddings (bool value);
void set_embeddings_pre_norm(bool value);
void set_causal_attn(bool value);
void set_warmup(bool value);
@@ -278,6 +282,11 @@ private:
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
buffer_view<float> embd = {nullptr, 0};
// hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd])
// populated only when cparams.embeddings_pre_norm is enabled and the model graph
// sets llm_graph_result::t_h_pre_norm
buffer_view<float> embd_pre_norm = {nullptr, 0};
struct sampling_info {
// !samplers.empty() to check if any samplers are active
std::map<llama_seq_id, llama_sampler *> samplers;
+3
View File
@@ -12,6 +12,7 @@ struct llama_cparams {
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;
uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing
@@ -27,6 +28,7 @@ struct llama_cparams {
float yarn_beta_slow;
bool embeddings;
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
bool causal_attn;
bool offload_kqv;
bool flash_attn;
@@ -40,6 +42,7 @@ struct llama_cparams {
bool kv_unified;
bool pipeline_parallel;
enum llama_context_type ctx_type;
enum llama_pooling_type pooling_type;
ggml_backend_sched_eval_callback cb_eval;
+16
View File
@@ -88,3 +88,19 @@ LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model);
LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i);
LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx);
//
// pre-norm embeddings (hidden state before the final output norm)
//
// mirrors:
// LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value);
// mirrors:
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
LLAMA_API float * llama_get_embeddings_pre_norm(struct llama_context * ctx);
// mirrors:
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i);
+2 -1
View File
@@ -2528,7 +2528,8 @@ ggml_tensor * llm_graph_context::build_rs(
int32_t rs_zero,
const llm_graph_get_rows_fn & get_state_rows) const {
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
GGML_UNUSED(rs_size);
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, s->ne[1]);
// Clear a single state which will then be copied to the other cleared states.
// Note that this is a no-op when the view is zero-sized.
+3
View File
@@ -32,6 +32,7 @@ enum llm_graph_type {
LLM_GRAPH_TYPE_DEFAULT,
LLM_GRAPH_TYPE_ENCODER,
LLM_GRAPH_TYPE_DECODER,
LLM_GRAPH_TYPE_DECODER_MTP,
};
enum llm_ffn_op_type {
@@ -644,6 +645,7 @@ public:
ggml_tensor * get_logits() const { return t_logits; }
ggml_tensor * get_embd() const { return t_embd; }
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; }
ggml_cgraph * get_gf() const { return gf; }
ggml_context * get_ctx() const { return ctx_compute.get(); }
@@ -672,6 +674,7 @@ public:
ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
std::map<llama_seq_id, ggml_tensor*> t_candidates;
+6
View File
@@ -229,6 +229,12 @@ uint32_t llama_hparams::n_embd_head_v_mla() const {
}
bool llama_hparams::has_kv(uint32_t il) const {
if (kv_only_nextn) {
// MTP head: only the trailing nextn_predict_layers blocks own a KV cache;
// the leading trunk blocks are not executed in this graph.
return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers);
}
if (n_layer_kv_from_start >= 0) {
if (il < (uint32_t) n_layer_kv_from_start) {
return true;
+2
View File
@@ -92,6 +92,8 @@ struct llama_hparams {
uint32_t moe_latent_size = 0;
uint32_t nextn_predict_layers = 0;
bool kv_only_nextn = false; // if true, only the last nextn_predict_layers blocks have a KV cache (MTP head arches)
float f_norm_eps;
float f_norm_rms_eps;
float f_norm_group_eps;
+2
View File
@@ -24,6 +24,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
uint32_t n_rs_seq,
bool offload,
bool unified,
/* layer filters */
@@ -54,6 +55,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
offload,
rs_size,
n_seq_max,
n_rs_seq,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr
+1
View File
@@ -34,6 +34,7 @@ public:
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
uint32_t n_rs_seq,
bool offload,
bool unified,
/* layer filters */
+2
View File
@@ -24,6 +24,7 @@ llama_memory_hybrid::llama_memory_hybrid(
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
uint32_t n_rs_seq,
bool offload,
bool unified,
/* layer filters */
@@ -54,6 +55,7 @@ llama_memory_hybrid::llama_memory_hybrid(
offload,
rs_size,
n_seq_max,
n_rs_seq,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr
+1
View File
@@ -34,6 +34,7 @@ public:
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
uint32_t n_rs_seq,
bool offload,
bool unified,
/* layer filters */
+104 -15
View File
@@ -24,6 +24,7 @@ llama_memory_recurrent::llama_memory_recurrent(
bool offload,
uint32_t mem_size,
uint32_t n_seq_max,
uint32_t n_rs_seq,
const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
const int32_t n_layer = hparams.n_layer;
@@ -31,6 +32,9 @@ llama_memory_recurrent::llama_memory_recurrent(
size = mem_size;
used = 0;
this->n_rs_seq = n_rs_seq;
rs_idx.assign(n_seq_max, 0);
cells.clear();
cells.resize(mem_size);
@@ -92,8 +96,9 @@ llama_memory_recurrent::llama_memory_recurrent(
throw std::runtime_error("failed to create ggml context for rs cache");
}
ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), mem_size);
ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), mem_size);
const uint32_t n_rows = mem_size * (1 + n_rs_seq);
ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), n_rows);
ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows);
ggml_format_name(r, "cache_r_l%d", i);
ggml_format_name(s, "cache_s_l%d", i);
r_l[i] = r;
@@ -115,8 +120,8 @@ llama_memory_recurrent::llama_memory_recurrent(
const size_t memory_size_r = size_r_bytes();
const size_t memory_size_s = size_s_bytes();
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
(float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max,
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs %2u rs_seq), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
(float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, n_rs_seq,
ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
}
@@ -138,10 +143,11 @@ void llama_memory_recurrent::clear(bool data) {
ggml_backend_buffer_clear(buf.get(), 0);
}
}
std::fill(rs_idx.begin(), rs_idx.end(), 0);
}
bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
//printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1);
uint32_t new_head = size;
if (p0 < 0) {
@@ -152,6 +158,15 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
p1 = std::numeric_limits<llama_pos>::max();
}
const bool rm_all = p0 == 0 && p1 == std::numeric_limits<llama_pos>::max();
if (rm_all) {
if (seq_id >= 0) {
set_rs_idx(seq_id, 0);
} else {
std::fill(rs_idx.begin(), rs_idx.end(), 0);
}
}
// models like Mamba or RWKV can't have a state partially erased at the end
// of the sequence because their state isn't preserved for previous tokens
if (seq_id >= (int64_t) size) {
@@ -161,10 +176,16 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
if (0 <= seq_id) {
int32_t & tail_id = cells[seq_id].tail;
if (tail_id >= 0) {
const auto & cell = cells[tail_id];
// partial intersection is invalid if it includes the final pos
auto & cell = cells[tail_id];
// partial rollback via per-token snapshot index (bounded by n_rs_seq)
if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1);
const llama_pos rollback = cell.pos - (p0 - 1);
if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) {
set_rs_idx(seq_id, (uint32_t) rollback);
cell.pos = p0 - 1;
return true;
}
return false;
}
// invalidate tails which will be cleared
@@ -368,6 +389,13 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
return result;
}
void llama_memory_recurrent::set_rs_idx(llama_seq_id seq_id, uint32_t idx) {
if (seq_id < 0 || (size_t) seq_id >= rs_idx.size()) {
return;
}
rs_idx[seq_id] = (idx > n_rs_seq) ? n_rs_seq : idx;
}
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const auto & [_, buf] : ctxs_bufs) {
@@ -703,6 +731,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
GGML_UNUSED(flags);
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges_data; // logical source row ranges
uint32_t cell_count = 0;
// Count the number of cells with the specified seq_id
@@ -712,6 +741,35 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
const auto & cell = cells[i];
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
++cell_count;
uint32_t rs_idx_cur = 0;
if (n_rs_seq != 0) {
if (seq_id != -1) {
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < rs_idx.size());
rs_idx_cur = rs_idx[seq_id];
} else {
bool has_rs_idx = false;
for (const llama_seq_id cell_seq_id : cell.seq_id) {
GGML_ASSERT(cell_seq_id >= 0 && (size_t) cell_seq_id < rs_idx.size());
const uint32_t seq_rs_idx = rs_idx[cell_seq_id];
if (!has_rs_idx) {
rs_idx_cur = seq_rs_idx;
has_rs_idx = true;
} else if (rs_idx_cur != seq_rs_idx) {
GGML_ABORT("cannot write shared recurrent state with different rollback indices");
}
}
}
}
const uint32_t cell_id = rs_idx_cur * size + (cell.src >= 0 ? cell.src : (int32_t) i);
if (cell_ranges_data.empty() || cell_ranges_data.back().second != cell_id) {
cell_ranges_data.emplace_back(cell_id, cell_id + 1);
} else {
cell_ranges_data.back().second++;
}
if (cell_range_begin == size) {
cell_range_begin = i;
}
@@ -726,7 +784,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
cell_ranges.emplace_back(cell_range_begin, size);
}
if (flags % LLAMA_STATE_SEQ_FLAGS_ON_DEVICE && cell_ranges.size() > 1) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) && cell_ranges.size() > 1) {
GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n");
}
@@ -737,10 +795,16 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
}
GGML_ASSERT(cell_count == cell_count_check);
cell_count_check = 0;
for (const auto & range : cell_ranges_data) {
cell_count_check += range.second - range.first;
}
GGML_ASSERT(cell_count == cell_count_check);
io.write(&cell_count, sizeof(cell_count));
state_write_meta(io, cell_ranges, seq_id);
state_write_data(io, cell_ranges);
state_write_data(io, cell_ranges_data);
}
void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
@@ -762,6 +826,14 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i
}
throw std::runtime_error("failed to restore kv cache");
}
if (n_rs_seq != 0) {
if (seq_id == -1) {
std::fill(rs_idx.begin(), rs_idx.end(), 0);
} else {
set_rs_idx(seq_id, 0);
}
}
}
void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
@@ -804,7 +876,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
io.write(&r_size_row, sizeof(r_size_row));
// Write each range of cells of r_size_row length
// Write each logical cell row range. With pending recurrent rollback,
// the logical current state may live in a rollback snapshot plane.
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t buf_size = range_size * r_size_row;
@@ -825,7 +898,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
io.write(&s_size_row, sizeof(s_size_row));
// Write each range of S tensor rows
// Write each logical cell row range. With pending recurrent rollback,
// the logical current state may live in a rollback snapshot plane.
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t buf_size = range_size * s_size_row;
@@ -852,9 +926,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
// Write GQA embedding size
io.write(&n_embd_s, sizeof(n_embd_s));
// For each row, we get the element values of each cell
// For each row, we get the element values of each logical cell
for (uint32_t j = 0; j < n_embd_s; ++j) {
// Write each range of cells of s_size_el length
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t src_offset = (range.first + j * mem_size) * s_size_el;
@@ -1163,5 +1236,21 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
}
int32_t llama_memory_recurrent_context::s_copy(int i) const {
return mem->cells[i + mem->head].src0;
const uint32_t cell_idx = i + mem->head;
const int32_t src0 = mem->cells[cell_idx].src0;
if (mem->n_rs_seq == 0) {
return src0;
}
uint32_t idx = 0;
if (!mem->cells[cell_idx].seq_id.empty()) {
const llama_seq_id seq = *mem->cells[cell_idx].seq_id.begin();
if (seq >= 0 && (size_t) seq < mem->rs_idx.size()) {
idx = mem->rs_idx[seq];
// reset rollback idx
mem->rs_idx[seq] = 0;
}
}
return (int32_t)(idx * mem->size) + src0;
}
+8
View File
@@ -23,6 +23,7 @@ public:
bool offload,
uint32_t mem_size,
uint32_t n_seq_max,
uint32_t n_rs_seq,
const layer_filter_cb & filter);
~llama_memory_recurrent() = default;
@@ -69,6 +70,13 @@ public:
uint32_t size = 0; // total number of cells, shared across all sequences
uint32_t used = 0; // used cells (i.e. at least one seq_id)
// number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups
uint32_t n_rs_seq = 0;
// per-seq rollback index
std::vector<uint32_t> rs_idx;
void set_rs_idx(llama_seq_id seq_id, uint32_t idx);
// computed before each graph build
uint32_t n = 0;
+3
View File
@@ -1,6 +1,7 @@
#pragma once
#include "llama.h"
#include "llama-graph.h"
#include <map>
#include <memory>
@@ -20,6 +21,8 @@ struct llama_memory_params {
// use full-size SWA cache
bool swa_full;
llama_context_type ctx_type;
};
enum llama_memory_status {
+9 -2
View File
@@ -1312,10 +1312,17 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte
return tensor;
}
void llama_model_loader::done_getting_tensors() const {
if (n_created != n_tensors) {
void llama_model_loader::done_getting_tensors(bool partial) const {
if (n_created > n_tensors) {
throw std::runtime_error(format("%s: too many tensors created; expected %d, got %d", __func__, n_tensors, n_created));
}
if (n_created < n_tensors) {
if (!partial) {
throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
}
LLAMA_LOG_INFO("%s: partial load — used %d of %d tensors in the file (rest belong to a sibling model on the same .gguf)\n",
__func__, n_created, n_tensors);
}
if (n_tensors_moved > 0) {
LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n",
__func__, first_tensor_moved_name.c_str(), first_tensor_moved_type_name.c_str(), n_tensors_moved - 1,
+1 -1
View File
@@ -184,7 +184,7 @@ struct llama_model_loader {
struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required = true);
void done_getting_tensors() const;
void done_getting_tensors(bool partial = false) const;
void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr);
+27 -3
View File
@@ -1947,6 +1947,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
// checks
default:
{
// The MTP head is dense-attention only on hybrid Qwen3.5/3.6, so use a plain
// attention KV cache for the MTP context instead of the hybrid wrapper.
const bool mtp_on_hybrid_qwen35 =
params.ctx_type == LLAMA_CONTEXT_TYPE_MTP &&
(arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE);
if (llm_arch_is_recurrent(arch)) {
res = new llama_memory_recurrent(
*this,
@@ -1955,8 +1961,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.offload_kqv,
std::max((uint32_t) 1, cparams.n_seq_max),
cparams.n_seq_max,
cparams.n_rs_seq,
nullptr);
} else if (llm_arch_is_hybrid(arch)) {
} else if (llm_arch_is_hybrid(arch) && !mtp_on_hybrid_qwen35) {
// The main difference between hybrid architectures is the
// layer filters, so pick the right one here
llama_memory_hybrid::layer_filter_cb filter_attn = nullptr;
@@ -1971,6 +1978,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
filter_recr = [&](int32_t il) {
return hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
};
} else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) {
const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers;
filter_attn = [&, n_main](int32_t il) {
return (uint32_t)il < n_main && !hparams.is_recurrent(il);
};
filter_recr = [&, n_main](int32_t il) {
return (uint32_t)il < n_main && hparams.is_recurrent(il);
};
}
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
@@ -1988,6 +2003,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* recurrent_type_s */ GGML_TYPE_F32,
/* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max),
/* n_seq_max */ cparams.n_seq_max,
/* n_rs_seq */ cparams.n_rs_seq,
/* offload */ cparams.offload_kqv,
/* unified */ cparams.kv_unified,
/* filter_attn */ std::move(filter_attn),
@@ -2006,6 +2022,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* recurrent_type_v */ GGML_TYPE_F32,
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
/* n_seq_max */ cparams.n_seq_max,
/* n_rs_seq */ cparams.n_rs_seq,
/* offload */ cparams.offload_kqv,
/* unified */ cparams.kv_unified,
/* filter_attn */ std::move(filter_attn),
@@ -2013,6 +2030,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
}
} else {
llama_memory_i::layer_reuse_cb reuse = nullptr;
llama_kv_cache::layer_filter_cb filter = nullptr;
if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) {
reuse = [&](int32_t il) {
@@ -2024,6 +2042,11 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
};
}
if (mtp_on_hybrid_qwen35) {
const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers;
filter = [n_main](int32_t il) { return (uint32_t)il >= n_main; };
}
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.is_swa_any());
@@ -2039,7 +2062,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.n_seq_max,
cparams.n_ubatch,
1,
nullptr,
filter,
reuse);
} else {
GGML_ASSERT(!hparams.is_swa_any());
@@ -2056,7 +2079,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1,
hparams.n_swa,
hparams.swa_type,
nullptr,
filter,
nullptr);
}
}
@@ -2159,6 +2182,7 @@ int32_t llama_model_n_swa(const llama_model * model) {
return model->hparams.n_swa;
}
uint32_t llama_model_n_cls_out(const struct llama_model * model) {
return model->hparams.n_cls_out;
}
+142 -1
View File
@@ -1,6 +1,7 @@
#include "models.h"
#include "llama-impl.h"
#include "llama-memory-recurrent.h"
// utility to get one slice from the third dimension
// input dim: [x, y, c, b]
@@ -397,7 +398,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
// K=1 (final state only): reshape to 3D (S_v*S_v*H_v, 1, n_seqs) for ggml_gated_delta_net.
ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, S_v * S_v * H_v, 1, n_seqs);
ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d);
if (n_tokens == 1) {
cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il);
} else {
@@ -443,3 +446,141 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
return build_delta_net_chunking(q, k, v, g, b, s, il);
}
bool llm_build_delta_net_base::keep_rs() const {
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
return cparams.n_rs_seq > 0
&& n_seq_tokens > 1
&& (uint32_t) n_seq_tokens <= 1 + cparams.n_rs_seq;
}
ggml_tensor * llm_build_delta_net_base::build_conv_state(
llm_graph_input_rs * inp,
ggml_tensor * conv_states_all,
ggml_tensor * qkv_mixed,
int64_t conv_kernel_size,
int64_t conv_channels,
int il) {
const auto * mctx_cur = inp->mctx;
const auto kv_head = mctx_cur->get_head();
const uint32_t mem_size = mctx_cur->get_size();
const int64_t n_seqs = ubatch.n_seqs;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const bool keep = keep_rs();
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
cb(conv_states, "conv_states", il);
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
cb(conv_states, "conv_states_reshaped", il);
qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
cb(qkv_mixed, "qkv_mixed_transposed", il);
ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
cb(conv_input, "conv_input", il);
if (!keep) {
ggml_tensor * last_conv_states =
ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
cb(last_conv_states, "last_conv_states", il);
ggml_tensor * state_update_target =
ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
cb(state_update_target, "state_update_target", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
} else {
const int64_t row_count = (conv_kernel_size - 1) * conv_channels;
const size_t row_size = row_count * ggml_element_size(conv_states_all);
for (int64_t t = 1; t <= n_seq_tokens; ++t) {
const uint32_t slot = (uint32_t)(n_seq_tokens - t);
ggml_tensor * src =
ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs,
conv_input->nb[1], conv_input->nb[2],
t * ggml_element_size(conv_input));
ggml_tensor * dst =
ggml_view_2d(ctx0, conv_states_all, row_count, n_seqs,
conv_states_all->nb[1],
((size_t) slot * mem_size + kv_head) * row_size);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
}
}
return conv_input;
}
ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
llm_graph_input_rs * inp,
ggml_tensor * ssm_states_all,
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * b,
ggml_tensor * s,
int il) {
const auto * mctx_cur = inp->mctx;
const auto kv_head = mctx_cur->get_head();
const uint32_t mem_size = mctx_cur->get_size();
const int64_t S_v = s->ne[0];
const int64_t H_v = s->ne[2];
const int64_t n_seqs = s->ne[3];
const int64_t n_seq_tokens = q->ne[2];
if (!keep_rs()) {
auto attn_out = build_delta_net(q, k, v, g, b, s, il);
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
cb(output, "attn_output", il);
cb(new_state, "new_state", il);
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, new_state,
ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
return output;
}
const int64_t D = S_v * S_v * H_v;
const int64_t K = (int64_t) cparams.n_rs_seq + 1;
// TODO: remove pad + simplify
ggml_tensor * state_in_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs);
ggml_tensor * state_3d = ggml_pad(ctx0, state_in_3d, 0, K - 1, 0, 0);
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, state_3d);
cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il);
const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs;
const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs;
ggml_tensor * output = ggml_view_4d(ctx0, gdn_out,
S_v, H_v, n_seq_tokens, n_seqs,
ggml_row_size(gdn_out->type, S_v),
ggml_row_size(gdn_out->type, S_v * H_v),
ggml_row_size(gdn_out->type, S_v * H_v * n_seq_tokens),
0);
cb(output, "attn_output", il);
const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all);
for (int64_t k_i = 0; k_i < K; ++k_i) {
const uint32_t cache_slot = (uint32_t) (K - 1 - k_i);
ggml_tensor * src = ggml_view_4d(ctx0, gdn_out,
S_v, S_v, H_v, n_seqs,
ggml_row_size(gdn_out->type, S_v),
ggml_row_size(gdn_out->type, S_v * S_v),
ggml_row_size(gdn_out->type, S_v * S_v * H_v),
ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap));
ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all,
hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
((size_t) cache_slot * mem_size + kv_head) * row_size);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
}
return output;
}
+35 -1
View File
@@ -46,7 +46,7 @@ struct llm_build_delta_net_base : public llm_graph_context {
ggml_tensor * s,
int il);
// use the ggml_gated_delta_net fused operator
// use the ggml_gated_delta_net fused operator (K=1; state has shape (D, 1, n_seqs))
std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_fused(
ggml_tensor * q,
ggml_tensor * k,
@@ -65,6 +65,32 @@ struct llm_build_delta_net_base : public llm_graph_context {
ggml_tensor * b,
ggml_tensor * s,
int il);
// true when speculative rollback is enabled and the batch fits in the rs cache
bool keep_rs() const;
// read conv state from cache, concat with qkv_mixed, write back (single slot or per-token)
// qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs)
ggml_tensor * build_conv_state(
llm_graph_input_rs * inp,
ggml_tensor * conv_states_all,
ggml_tensor * qkv_mixed,
int64_t conv_kernel_size,
int64_t conv_channels,
int il);
// run delta-net attention and write the new recurrent state(s) back to ssm_states_all
// s: (head_v_dim, head_v_dim, num_v_heads, n_seqs); returns output: (head_v_dim, num_v_heads, n_seq_tokens, n_seqs)
ggml_tensor * build_recurrent_attn(
llm_graph_input_rs * inp,
ggml_tensor * ssm_states_all,
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * b,
ggml_tensor * s,
int il);
};
struct llm_build_rwkv6_base : public llm_graph_context {
@@ -1739,6 +1765,10 @@ struct llama_model_qwen35 : public llama_model_base {
const llama_model & model;
};
struct graph_mtp : public llm_graph_context {
graph_mtp(const llama_model & model, const llm_graph_params & params);
};
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
};
@@ -1781,6 +1811,10 @@ struct llama_model_qwen35moe : public llama_model_base {
const llama_model & model;
};
struct graph_mtp : public llm_graph_context {
graph_mtp(const llama_model & model, const llm_graph_params & params);
};
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
};
+225 -71
View File
@@ -12,16 +12,22 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
// Mark recurrent layers (linear attention layers)
// NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
// Mark recurrent layers (linear attention layers). MTP layers are dense
// attention-only and must be flagged non-recurrent.
{
const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers;
uint32_t full_attn_interval = 4;
ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false);
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0);
hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0);
}
}
switch (hparams.n_layer) {
switch (hparams.n_layer - hparams.nextn_predict_layers) {
case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break;
case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break;
case 64: type = LLM_TYPE_27B; break;
@@ -29,9 +35,14 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) {
}
}
void llama_model_qwen35::load_arch_tensors(llama_model_loader &) {
void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) {
LLAMA_LOAD_LOCALS;
const uint32_t n_main = n_layer - hparams.nextn_predict_layers;
const bool mtp_only = (hparams.nextn_predict_layers > 0) &&
(ml.get_weight("blk.0.attn_norm.weight") == nullptr);
const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0;
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
// output
@@ -43,6 +54,9 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
}
auto load_block_trunk = [&](int il, int flags) {
auto & layer = layers[il];
// Calculate dimensions from hyperparameters
const int64_t head_k_dim = hparams.ssm_d_state;
const int64_t head_v_dim = hparams.ssm_d_state;
@@ -52,41 +66,73 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) {
const int64_t value_dim = head_v_dim * n_v_heads;
const int64_t conv_dim = key_dim * 2 + value_dim;
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags);
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
if (!hparams.is_recurrent(i)) {
if (!hparams.is_recurrent(il)) {
// Attention layers
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags);
// Q/K normalization for attention layers
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags);
} else {
// Linear attention (gated delta net) specific tensors
// Create tensors with calculated dimensions
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0);
layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0);
layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags);
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags);
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags);
layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags);
layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags);
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags);
}
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, flags);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, flags);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, flags);
};
auto load_block_mtp = [&](int il) {
auto & layer = layers[il];
// MTP block looks like a full-attention Qwen3.5 decoder block.
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0);
create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, 0);
// NextN-specific tensors that define the MTP block.
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0);
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED);
};
for (int i = 0; i < (int) n_main; ++i) {
load_block_trunk(i, trunk_flags);
}
for (int i = (int) n_main; i < n_layer; ++i) {
load_block_mtp(i);
}
}
std::unique_ptr<llm_graph_context> llama_model_qwen35::build_arch_graph(const llm_graph_params & params) const {
if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) {
return std::make_unique<graph_mtp>(*this, params);
}
return std::make_unique<graph>(*this, params);
}
@@ -111,7 +157,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
@@ -128,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
}
if (il == n_layer - 1 && inp_out_ids) {
if (il == n_transformer_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
@@ -160,6 +208,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
}
cur = inpL;
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
// Final norm
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
@@ -297,8 +348,6 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear(
const int64_t head_v_dim = d_inner / num_v_heads;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const auto kv_head = mctx_cur->get_head();
GGML_ASSERT(n_seqs != 0);
GGML_ASSERT(ubatch.equal_seqs());
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
@@ -328,41 +377,14 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear(
gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
// Get convolution states from cache
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
// Build the convolution states tensor
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
cb(conv_states, "conv_states", il);
// Calculate convolution kernel size
ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
const int64_t conv_kernel_size = conv_kernel->ne[0];
const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
cb(conv_states, "conv_states_reshaped", il);
qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
cb(qkv_mixed, "qkv_mixed_transposed", il);
ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
cb(conv_input, "conv_input", il);
// Update convolution state cache
// Extract the last (conv_kernel_size - 1) states from conv_input
ggml_tensor * last_conv_states =
ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
cb(last_conv_states, "last_conv_states", il);
ggml_tensor * state_update_target =
ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
cb(state_update_target, "state_update_target", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il);
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
@@ -413,7 +435,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear(
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
// if head keys and value keys are different, repeat to force tensors into matching shapes
// note: need explicit repeat only if we are not using the fused GDN
// note: need explicit repeat only if we are not using the fused GDN.
if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
@@ -424,18 +446,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear(
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
cb(output, "attn_output", il);
cb(new_state, "new_state", il);
// Update the recurrent states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, new_state,
ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il);
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
@@ -471,3 +482,146 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, cons
return cur;
}
// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 dense series
llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params)
: llm_graph_context(params) {
GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35 MTP requires nextn_predict_layers > 0");
GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35 MTP currently only supports a single MTP block");
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
// The MTP block lives at the source file's original layer index.
const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
const auto & layer = model.layers[il];
GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");
GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm");
GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm");
int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd);
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_set_input(inp->tokens);
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
ggml_set_input(inp->embd);
ggml_set_name(inp->embd, "mtp_h_input");
ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
ggml_tensor * h_input = inp->embd;
ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens);
cb(tok_embd, "mtp_tok_embd", il);
res->add_input(std::move(inp));
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
cb(h_norm, "mtp_hnorm", il);
ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il);
cb(e_norm, "mtp_enorm", il);
ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0);
cb(concat, "mtp_concat", il);
ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat);
cb(cur, "mtp_eh_proj", il);
ggml_tensor * inpSA = cur;
cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "mtp_attn_norm", il);
ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s);
cb(Qcur_full, "mtp_Qcur_full", il);
ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full,
n_embd_head, n_head, n_tokens,
ggml_element_size(Qcur_full) * n_embd_head * 2,
ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
0);
Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "mtp_Qcur_normed", il);
ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full,
n_embd_head, n_head, n_tokens,
ggml_element_size(Qcur_full) * n_embd_head * 2,
ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
ggml_element_size(Qcur_full) * n_embd_head);
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
cb(gate, "mtp_gate", il);
ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il);
cb(Kcur, "mtp_Kcur_normed", il);
ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
cb(Vcur, "mtp_Vcur", il);
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
const float kq_scale = hparams.f_attention_scale == 0.0f
? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
cur = build_attn(inp_attn,
nullptr, nullptr, nullptr,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
cb(cur, "mtp_attn_pregate", il);
cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
cur = build_lora_mm(layer.wo, cur, layer.wo_s);
cb(cur, "mtp_attn_out", il);
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "mtp_attn_residual", il);
ggml_tensor * ffn_residual = cur;
cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "mtp_attn_post_norm", il);
cur = build_ffn(cur,
layer.ffn_up, nullptr, layer.ffn_up_s,
layer.ffn_gate, nullptr, layer.ffn_gate_s,
layer.ffn_down, nullptr, layer.ffn_down_s,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "mtp_ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "mtp_post_ffn", il);
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
// (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.)
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
? layer.nextn.shared_head_norm
: model.output_norm;
GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm");
cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1);
cb(cur, "mtp_shared_head_norm", -1);
ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
GGML_ASSERT(head_w && "QWEN35 MTP: missing LM head (nextn.shared_head_head or model.output)");
cur = build_lora_mm(head_w, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
+271 -75
View File
@@ -15,16 +15,22 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
// Mark recurrent layers (linear attention layers)
// NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer");
// Mark recurrent layers (linear attention layers). MTP layers are dense
// attention-only and must be flagged non-recurrent.
{
const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers;
uint32_t full_attn_interval = 4;
ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false);
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0);
hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0);
}
}
switch (hparams.n_layer) {
switch (hparams.n_layer - hparams.nextn_predict_layers) {
case 40: type = LLM_TYPE_35B_A3B; break;
case 48: type = LLM_TYPE_122B_A10B; break;
case 60: type = LLM_TYPE_397B_A17B; break;
@@ -32,9 +38,14 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) {
}
}
void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) {
void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) {
LLAMA_LOAD_LOCALS;
const uint32_t n_main = n_layer - hparams.nextn_predict_layers;
const bool mtp_only = (hparams.nextn_predict_layers > 0) &&
(ml.get_weight("blk.0.attn_norm.weight") == nullptr);
const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0;
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
// output
@@ -46,7 +57,11 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
}
auto load_block_trunk = [&](int il, int flags) {
auto & layer = layers[il];
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
// Calculate dimensions from hyperparameters
const int64_t head_k_dim = hparams.ssm_d_state;
@@ -57,49 +72,90 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) {
const int64_t value_dim = head_v_dim * n_v_heads;
const int64_t conv_dim = key_dim * 2 + value_dim;
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags);
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
if (!hparams.is_recurrent(i)) {
if (!hparams.is_recurrent(il)) {
// Attention layers
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags);
// Q/K normalization for attention layers
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags);
} else {
// Linear attention (gated delta net) specific tensors
// Create tensors with calculated dimensions
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0);
layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0);
layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags);
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags);
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags);
layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags);
layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags);
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags);
}
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0);
// Routed experts
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, flags);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, flags);
create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, flags);
// Shared experts
layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, flags);
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, flags);
};
auto load_block_mtp = [&](int il) {
auto & layer = layers[il];
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0);
// MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN.
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0);
create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0);
// Routed experts
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, 0);
create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, 0);
// Shared experts
layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, 0);
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, 0);
// NextN-specific tensors that define the MTP block.
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0);
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED);
};
for (int i = 0; i < (int) n_main; ++i) {
load_block_trunk(i, trunk_flags);
}
for (int i = (int) n_main; i < n_layer; ++i) {
load_block_mtp(i);
}
}
std::unique_ptr<llm_graph_context> llama_model_qwen35moe::build_arch_graph(const llm_graph_params & params) const {
if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) {
return std::make_unique<graph_mtp>(*this, params);
}
return std::make_unique<graph>(*this, params);
}
@@ -124,7 +180,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
@@ -141,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
}
if (il == n_layer - 1 && inp_out_ids) {
if (il == n_transformer_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
@@ -173,6 +231,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
}
cur = inpL;
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
// Final norm
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
@@ -310,8 +371,6 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear(
const int64_t head_v_dim = d_inner / num_v_heads;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const auto kv_head = mctx_cur->get_head();
GGML_ASSERT(n_seqs != 0);
GGML_ASSERT(ubatch.equal_seqs());
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
@@ -341,41 +400,14 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear(
gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
// Get convolution states from cache
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
// Build the convolution states tensor
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
cb(conv_states, "conv_states", il);
// Calculate convolution kernel size
ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
const int64_t conv_kernel_size = conv_kernel->ne[0];
const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
cb(conv_states, "conv_states_reshaped", il);
qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
cb(qkv_mixed, "qkv_mixed_transposed", il);
ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
cb(conv_input, "conv_input", il);
// Update convolution state cache
// Extract the last (conv_kernel_size - 1) states from conv_input
ggml_tensor * last_conv_states =
ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
cb(last_conv_states, "last_conv_states", il);
ggml_tensor * state_update_target =
ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
cb(state_update_target, "state_update_target", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il);
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
@@ -426,7 +458,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear(
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
// if head keys and value keys are different, repeat to force tensors into matching shapes
// note: need explicit repeat only if we are not using the fused GDN
// note: need explicit repeat only if we are not using the fused GDN.
if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
@@ -437,18 +469,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear(
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
cb(output, "attn_output", il);
cb(new_state, "new_state", il);
// Update the recurrent states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, new_state,
ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il);
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
@@ -525,3 +546,178 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, c
return cur;
}
// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 MoE
llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params)
: llm_graph_context(params) {
GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE MTP requires nextn_predict_layers > 0");
GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE MTP currently only supports a single MTP block");
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
const auto & layer = model.layers[il];
GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");
GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm");
GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm");
GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp");
int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd);
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_set_input(inp->tokens);
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
ggml_set_input(inp->embd);
ggml_set_name(inp->embd, "mtp_h_input");
ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd;
ggml_tensor * h_input = inp->embd;
ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens);
cb(tok_embd, "mtp_tok_embd", il);
res->add_input(std::move(inp));
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
cb(h_norm, "mtp_hnorm", il);
ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il);
cb(e_norm, "mtp_enorm", il);
ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0);
cb(concat, "mtp_concat", il);
ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat);
cb(cur, "mtp_eh_proj", il);
ggml_tensor * inpSA = cur;
cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "mtp_attn_norm", il);
ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s);
cb(Qcur_full, "mtp_Qcur_full", il);
ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full,
n_embd_head, n_head, n_tokens,
ggml_element_size(Qcur_full) * n_embd_head * 2,
ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
0);
Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "mtp_Qcur_normed", il);
ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full,
n_embd_head, n_head, n_tokens,
ggml_element_size(Qcur_full) * n_embd_head * 2,
ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
ggml_element_size(Qcur_full) * n_embd_head);
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
cb(gate, "mtp_gate", il);
ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il);
cb(Kcur, "mtp_Kcur_normed", il);
ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
cb(Vcur, "mtp_Vcur", il);
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
const float kq_scale = hparams.f_attention_scale == 0.0f
? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
cur = build_attn(inp_attn,
nullptr, nullptr, nullptr,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
cb(cur, "mtp_attn_pregate", il);
cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
cur = build_lora_mm(layer.wo, cur, layer.wo_s);
cb(cur, "mtp_attn_out", il);
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "mtp_attn_residual", il);
ggml_tensor * ffn_residual = cur;
cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "mtp_attn_post_norm", il);
// MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe).
ggml_tensor * moe_out =
build_moe_ffn(cur,
layer.ffn_gate_inp,
layer.ffn_up_exps,
layer.ffn_gate_exps,
layer.ffn_down_exps,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
hparams.expert_weights_scale,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
nullptr, layer.ffn_gate_up_exps,
layer.ffn_up_exps_s,
layer.ffn_gate_exps_s,
layer.ffn_down_exps_s);
cb(moe_out, "mtp_ffn_moe_out", il);
if (layer.ffn_up_shexp != nullptr) {
ggml_tensor * ffn_shexp =
build_ffn(cur,
layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s,
layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s,
layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_shexp, "mtp_ffn_shexp", il);
ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur);
shared_gate = ggml_sigmoid(ctx0, shared_gate);
cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il);
ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
cb(ffn_shexp, "mtp_ffn_shexp_gated", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
} else {
cur = moe_out;
}
cb(cur, "mtp_ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "mtp_post_ffn", il);
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
? layer.nextn.shared_head_norm
: model.output_norm;
GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm");
cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1);
cb(cur, "mtp_shared_head_norm", -1);
ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
GGML_ASSERT(head_w && "QWEN35MOE MTP: missing LM head (nextn.shared_head_head or model.output)");
cur = build_lora_mm(head_w, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
+2 -42
View File
@@ -378,8 +378,6 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear(
const int64_t head_v_dim = d_inner / num_v_heads;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const auto kv_head = mctx_cur->get_head();
GGML_ASSERT(n_seqs != 0);
GGML_ASSERT(ubatch.equal_seqs());
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
@@ -429,41 +427,14 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear(
beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
// Get convolution states from cache
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
// Build the convolution states tensor
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
cb(conv_states, "conv_states", il);
// Calculate convolution kernel size
ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
const int64_t conv_kernel_size = conv_kernel->ne[0];
const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
cb(conv_states, "conv_states_reshaped", il);
qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
cb(qkv_mixed, "qkv_mixed_transposed", il);
ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
cb(conv_input, "conv_input", il);
// Update convolution state cache
// Extract the last (conv_kernel_size - 1) states from conv_input
ggml_tensor * last_conv_states =
ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
cb(last_conv_states, "last_conv_states", il);
ggml_tensor * state_update_target =
ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
cb(state_update_target, "state_update_target", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il);
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
@@ -540,18 +511,7 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear(
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
cb(output, "attn_output", il);
cb(new_state, "new_state", il);
// Update the recurrent states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, new_state,
ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il);
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+3
View File
@@ -252,6 +252,9 @@ llama_build_and_test(test-backend-sampler.cpp LABEL "model")
llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -m "${MODEL_DEST}")
set_tests_properties(test-state-restore-fragmented PROPERTIES FIXTURES_REQUIRED test-download-model)
llama_build_and_test(test-recurrent-state-rollback.cpp LABEL "model" ARGS -m "${MODEL_DEST}")
set_tests_properties(test-recurrent-state-rollback PROPERTIES FIXTURES_REQUIRED test-download-model)
if (NOT GGML_BACKEND_DL)
# these tests use the backends directly and cannot be built with dynamic loading
llama_build_and_test(test-barrier.cpp)
+17 -4
View File
@@ -3832,16 +3832,17 @@ struct test_gated_delta_net : public test_case {
const int v_repeat;
const bool permuted;
const bool kda;
const int64_t K; // snapshot slot count: 1 = final-only, >1 = last K states
std::string vars() override {
return VARS_TO_STR8(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted, kda);
return VARS_TO_STR9(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted, kda, K);
}
test_gated_delta_net(ggml_type type = GGML_TYPE_F32,
int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1,
int v_repeat = 1, bool permuted = false, bool kda = false)
int v_repeat = 1, bool permuted = false, bool kda = false, int64_t K = 1)
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs),
v_repeat(v_repeat), permuted(permuted), kda(kda) {}
v_repeat(v_repeat), permuted(permuted), kda(kda), K(K) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * q;
@@ -3863,7 +3864,7 @@ struct test_gated_delta_net : public test_case {
const int64_t g_ne0 = kda ? head_size : 1;
ggml_tensor * g = ggml_new_tensor_4d(ctx, type, g_ne0, head_count * v_repeat, n_seq_tokens, n_seqs);
ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, head_count * v_repeat, n_seq_tokens, n_seqs);
ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * v_repeat * head_size * head_count, n_seqs);
ggml_tensor * state = ggml_new_tensor_3d(ctx, type, head_size * v_repeat * head_size * head_count, K, n_seqs);
ggml_set_name(g, "g");
ggml_set_name(beta, "beta");
ggml_set_name(state, "state");
@@ -9034,6 +9035,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 33, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 100, 1, 1, false, true));
// K > 1: output keeps the last min(n_tokens, K) per-token snapshots in the trailing K-token region.
// exact-match cases (K == n_seq_tokens):
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 2, 1, 1, false, false, /*K=*/2));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, false, /*K=*/4));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, false, /*K=*/4));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 128, 4, 1, 1, false, false, /*K=*/4));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true, /*K=*/4));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true, /*K=*/4));
// overflow: n_tokens > K — only the last K snapshots kept.
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 8, 1, 1, false, false, /*K=*/3));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 16, 2, 1, false, false, /*K=*/4));
#if 0
// these tests are disabled to save execution time, sbut they can be handy for debugging
test_cases.emplace_back(new test_llama(2, true));
+185
View File
@@ -0,0 +1,185 @@
#include "arg.h"
#include "common.h"
#include "llama.h"
#include <algorithm>
#include <clocale>
#include <cmath>
#include <cstdio>
#include <vector>
static llama_context * make_ctx(const common_params & params, llama_model * model) {
auto cparams = common_context_params_to_llama(params);
cparams.n_seq_max = 1;
cparams.n_rs_seq = 8;
cparams.n_batch = std::max(cparams.n_batch, (uint32_t) (cparams.n_rs_seq + 1));
cparams.n_ubatch = std::max(cparams.n_ubatch, (uint32_t) (cparams.n_rs_seq + 1));
return llama_init_from_model(model, cparams);
}
static bool decode_tokens(llama_context * ctx, const std::vector<llama_token> & tokens, uint32_t count) {
llama_batch batch = llama_batch_init(count, 0, 1);
for (uint32_t pos = 0; pos < count; ++pos) {
common_batch_add(batch, tokens[pos], pos, { 0 }, false);
}
const bool ok = llama_decode(ctx, batch) == 0;
llama_batch_free(batch);
return ok;
}
static bool decode_one(llama_context * ctx, llama_token tok, llama_pos pos) {
llama_batch batch = llama_batch_init(1, 0, 1);
common_batch_add(batch, tok, pos, { 0 }, true);
const bool ok = llama_decode(ctx, batch) == 0;
llama_batch_free(batch);
return ok;
}
int main(int argc, char ** argv) {
std::setlocale(LC_NUMERIC, "C");
common_params params;
params.sampling.seed = 1234;
params.n_predict = 1;
common_init();
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
return 1;
}
ggml_backend_load_all();
common_init_result_ptr llama_init = common_init_from_params(params);
llama_model * model = llama_init->model();
if (model == nullptr) {
fprintf(stderr, "%s : failed to init model\n", __func__);
return 1;
}
if (!llama_model_is_recurrent(model) && !llama_model_is_hybrid(model)) {
fprintf(stderr, "%s : skipping for non-recurrent model\n", __func__);
return 0;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_vocab_n_tokens(vocab);
llama_context * ctx_src = make_ctx(params, model);
llama_context * ctx_dst = make_ctx(params, model);
if (ctx_src == nullptr || ctx_dst == nullptr) {
fprintf(stderr, "%s : failed to init contexts\n", __func__);
return 1;
}
if (llama_n_rs_seq(ctx_src) == 0) {
fprintf(stderr, "%s : skipping because n_rs_seq is disabled\n", __func__);
llama_free(ctx_src);
llama_free(ctx_dst);
return 0;
}
std::vector<llama_token> tokens = common_tokenize(ctx_src, "The quick brown fox jumps", true);
const uint32_t n_rs_seq = llama_n_rs_seq(ctx_src);
if (tokens.size() > n_rs_seq + 1) {
tokens.resize(n_rs_seq + 1);
}
if (tokens.size() < 2) {
fprintf(stderr, "%s : not enough prompt tokens\n", __func__);
return 1;
}
const uint32_t n_tokens = tokens.size();
const llama_token last_tok = tokens.back();
const llama_pos last_pos = (llama_pos) n_tokens - 2;
// Decode the full prompt on the source, then roll back the last position.
// Rollback leaves the recurrent memory in a snapshot state (rs_idx != 0).
if (!decode_tokens(ctx_src, tokens, n_tokens)) {
fprintf(stderr, "%s : failed to decode prompt\n", __func__);
return 1;
}
if (!llama_memory_seq_rm(llama_get_memory(ctx_src), 0, last_pos, -1)) {
fprintf(stderr, "%s : rollback failed\n", __func__);
return 1;
}
// Save the rolled-back state and restore it into a fresh context.
common_prompt_checkpoint ckpt;
ckpt.update_tgt(ctx_src, 0, 0);
ckpt.load_tgt(ctx_dst, 0, 0);
// Replay the rolled-back token on both contexts and compare logits.
if (!decode_one(ctx_src, last_tok, last_pos) ||
!decode_one(ctx_dst, last_tok, last_pos)) {
fprintf(stderr, "%s : replay failed\n", __func__);
return 1;
}
const float * logits_src = llama_get_logits_ith(ctx_src, 0);
const float * logits_dst = llama_get_logits_ith(ctx_dst, 0);
if (logits_src == nullptr || logits_dst == nullptr) {
fprintf(stderr, "%s : missing logits\n", __func__);
return 1;
}
constexpr float eps = 1e-5f;
for (int i = 0; i < n_vocab; ++i) {
if (std::fabs(logits_src[i] - logits_dst[i]) > eps) {
fprintf(stderr, "%s : logits mismatch at token %d (%g != %g)\n",
__func__, i, (double) logits_src[i], (double) logits_dst[i]);
return 1;
}
}
// Repeat the load into a context that already has its own rollback state:
// groups 1..n_rs_seq hold a *different* prompt's history, and rs_idx[0] is
// non-zero at load time. The restore must wipe that state and still match.
llama_context * ctx_dirty = make_ctx(params, model);
if (ctx_dirty == nullptr) {
fprintf(stderr, "%s : failed to init dirty ctx\n", __func__);
return 1;
}
std::vector<llama_token> noise = tokens;
for (auto & t : noise) {
t = (t + 1) % n_vocab;
if (t < 0) {
t = 0;
}
}
if (!decode_tokens(ctx_dirty, noise, n_tokens)) {
fprintf(stderr, "%s : dirty prompt decode failed\n", __func__);
return 1;
}
if (!llama_memory_seq_rm(llama_get_memory(ctx_dirty), 0, last_pos, -1)) {
fprintf(stderr, "%s : dirty rollback failed\n", __func__);
return 1;
}
ckpt.load_tgt(ctx_dirty, 0, 0);
if (!decode_one(ctx_dirty, last_tok, last_pos)) {
fprintf(stderr, "%s : dirty replay failed\n", __func__);
return 1;
}
const float * logits_dirty = llama_get_logits_ith(ctx_dirty, 0);
if (logits_dirty == nullptr) {
fprintf(stderr, "%s : missing dirty logits\n", __func__);
return 1;
}
for (int i = 0; i < n_vocab; ++i) {
if (std::fabs(logits_src[i] - logits_dirty[i]) > eps) {
fprintf(stderr, "%s : dirty-ctx logits mismatch at token %d (%g != %g)\n",
__func__, i, (double) logits_src[i], (double) logits_dirty[i]);
return 1;
}
}
fprintf(stderr, "%s : recurrent rollback checkpoint restored successfully\n", __func__);
llama_free(ctx_src);
llama_free(ctx_dst);
llama_free(ctx_dirty);
return 0;
}
+3 -4
View File
@@ -55,7 +55,6 @@
| `-ctv, --cache-type-v TYPE` | KV cache data type for V<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V) |
| `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
| `-np, --parallel N` | number of parallel sequences to decode (default: 1)<br/>(env: LLAMA_ARG_N_PARALLEL) |
| `--rpc SERVERS` | comma-separated list of RPC servers (host:port)<br/>(env: LLAMA_ARG_RPC) |
| `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
| `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)<br/>(env: LLAMA_ARG_MMAP) |
| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)<br/>(env: LLAMA_ARG_DIO) |
@@ -94,8 +93,8 @@
| `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) |
| `--offline` | Offline mode: forces use of cache, prevents network access<br/>(env: LLAMA_OFFLINE) |
| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:<br/> - 0: generic output<br/> - 1: error<br/> - 2: warning<br/> - 3: info<br/> - 4: debug<br/>(default: 3)<br/><br/>(env: LLAMA_LOG_VERBOSITY) |
| `--log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_LOG_PREFIX) |
| `--log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_LOG_TIMESTAMPS) |
| `--log-prefix, --no-log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_ARG_LOG_PREFIX) |
| `--log-timestamps, --no-log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_ARG_LOG_TIMESTAMPS) |
| `--spec-draft-type-k, -ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_K) |
| `--spec-draft-type-v, -ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_V) |
@@ -199,7 +198,7 @@
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
| `--spec-type none,draft-simple,draft-eagle3,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
| `--spec-type none,draft-simple,draft-eagle3,draft-mtp,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |
| `--spec-ngram-mod-n-match N` | ngram-mod lookup length (default: 24) |
+2 -3
View File
@@ -138,7 +138,6 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
| `-ctv, --cache-type-v TYPE` | KV cache data type for V<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V) |
| `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
| `-np, --parallel N` | number of parallel sequences to decode (default: 1)<br/>(env: LLAMA_ARG_N_PARALLEL) |
| `--rpc SERVERS` | comma-separated list of RPC servers (host:port)<br/>(env: LLAMA_ARG_RPC) |
| `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
| `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)<br/>(env: LLAMA_ARG_MMAP) |
| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)<br/>(env: LLAMA_ARG_DIO) |
@@ -177,8 +176,8 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
| `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) |
| `--offline` | Offline mode: forces use of cache, prevents network access<br/>(env: LLAMA_OFFLINE) |
| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:<br/> - 0: generic output<br/> - 1: error<br/> - 2: warning<br/> - 3: info<br/> - 4: debug<br/>(default: 3)<br/><br/>(env: LLAMA_LOG_VERBOSITY) |
| `--log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_LOG_PREFIX) |
| `--log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_LOG_TIMESTAMPS) |
| `--log-prefix, --no-log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_ARG_LOG_PREFIX) |
| `--log-timestamps, --no-log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_ARG_LOG_TIMESTAMPS) |
| `--spec-draft-type-k, -ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_K) |
| `--spec-draft-type-v, -ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_V) |
+11 -8
View File
@@ -72,7 +72,6 @@ For the full list of features, please refer to [server's changelog](https://gith
| `-ctk, --cache-type-k TYPE` | KV cache data type for K<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K) |
| `-ctv, --cache-type-v TYPE` | KV cache data type for V<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V) |
| `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
| `--rpc SERVERS` | comma-separated list of RPC servers (host:port)<br/>(env: LLAMA_ARG_RPC) |
| `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
| `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)<br/>(env: LLAMA_ARG_MMAP) |
| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)<br/>(env: LLAMA_ARG_DIO) |
@@ -111,8 +110,8 @@ For the full list of features, please refer to [server's changelog](https://gith
| `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) |
| `--offline` | Offline mode: forces use of cache, prevents network access<br/>(env: LLAMA_OFFLINE) |
| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:<br/> - 0: generic output<br/> - 1: error<br/> - 2: warning<br/> - 3: info<br/> - 4: debug<br/>(default: 3)<br/><br/>(env: LLAMA_LOG_VERBOSITY) |
| `--log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_LOG_PREFIX) |
| `--log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_LOG_TIMESTAMPS) |
| `--log-prefix, --no-log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_ARG_LOG_PREFIX) |
| `--log-timestamps, --no-log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_ARG_LOG_TIMESTAMPS) |
| `--spec-draft-type-k, -ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_K) |
| `--spec-draft-type-v, -ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_SPEC_DRAFT_CACHE_TYPE_V) |
@@ -189,11 +188,15 @@ For the full list of features, please refer to [server's changelog](https://gith
| `--reuse-port` | allow multiple sockets to bind to the same port (default: disabled)<br/>(env: LLAMA_ARG_REUSE_PORT) |
| `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
| `--api-prefix PREFIX` | prefix path the server serves from, without the trailing slash (default: )<br/>(env: LLAMA_ARG_API_PREFIX) |
| `--ui-config JSON` / `--webui-config JSON` (deprecated) | JSON that provides default UI settings (overrides UI defaults)<br/>(env: LLAMA_ARG_UI_CONFIG / LLAMA_ARG_WEBUI_CONFIG) |
| `--ui-config-file PATH` / `--webui-config-file PATH` (deprecated) | JSON file that provides default UI settings (overrides UI defaults)<br/>(env: LLAMA_ARG_UI_CONFIG_FILE / LLAMA_ARG_WEBUI_CONFIG_FILE) |
| `--ui-mcp-proxy, --no-ui-mcp-proxy` / `--webui-mcp-proxy, --no-webui-mcp-proxy` (deprecated) | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)<br/>(env: LLAMA_ARG_UI_MCP_PROXY / LLAMA_ARG_WEBUI_MCP_PROXY) |
| `--webui-config JSON` | [DEPRECATED: use --ui-config] JSON that provides default WebUI settings (overrides WebUI defaults)<br/>(env: LLAMA_ARG_WEBUI_CONFIG) |
| `--ui-config JSON` | JSON that provides default UI settings (overrides UI defaults)<br/>(env: LLAMA_ARG_UI_CONFIG) |
| `--webui-config-file PATH` | [DEPRECATED: use --ui-config-file] JSON file that provides default WebUI settings (overrides WebUI defaults)<br/>(env: LLAMA_ARG_WEBUI_CONFIG_FILE) |
| `--ui-config-file PATH` | JSON file that provides default UI settings (overrides UI defaults)<br/>(env: LLAMA_ARG_UI_CONFIG_FILE) |
| `--webui-mcp-proxy, --no-webui-mcp-proxy` | [DEPRECATED: use --ui-mcp-proxy/--no-ui-mcp-proxy] experimental: whether to enable MCP CORS proxy<br/>(env: LLAMA_ARG_WEBUI_MCP_PROXY) |
| `--ui-mcp-proxy, --no-ui-mcp-proxy` | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)<br/>(env: LLAMA_ARG_UI_MCP_PROXY) |
| `--tools TOOL1,TOOL2,...` | experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)<br/>specify "all" to enable all tools<br/>available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff, get_datetime<br/>(env: LLAMA_ARG_TOOLS) |
| `--ui, --no-ui` / `--webui, --no-webui` (deprecated) | whether to enable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_UI / LLAMA_ARG_WEBUI) |
| `--webui, --no-webui` | [DEPRECATED: use --ui/--no-ui] whether to enable the Web UI<br/>(env: LLAMA_ARG_WEBUI) |
| `--ui, --no-ui` | whether to enable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_UI) |
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
| `--rerank, --reranking` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
| `--api-key KEY` | API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)<br/>(env: LLAMA_API_KEY) |
@@ -248,7 +251,7 @@ For the full list of features, please refer to [server's changelog](https://gith
| `--spec-draft-device, -devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
| `--spec-draft-ngl, -ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
| `--spec-draft-model, -md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_MODEL) |
| `--spec-type none,draft-simple,draft-eagle3,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
| `--spec-type none,draft-simple,draft-eagle3,draft-mtp,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-mod,ngram-cache` | comma-separated list of types of speculative decoding to use (default: none)<br/><br/>(env: LLAMA_ARG_SPEC_TYPE) |
| `--spec-ngram-mod-n-min N` | minimum number of ngram tokens to use for ngram-based speculative decoding (default: 48) |
| `--spec-ngram-mod-n-max N` | maximum number of ngram tokens to use for ngram-based speculative decoding (default: 64) |
| `--spec-ngram-mod-n-match N` | ngram-mod lookup length (default: 24) |
+94 -47
View File
@@ -145,9 +145,9 @@ struct server_slot {
SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
llama_memory_seq_rm(llama_get_memory(ctx_tgt), id, -1, -1);
common_context_seq_rm(ctx_tgt, id, -1, -1);
if (ctx_dft) {
llama_memory_seq_rm(llama_get_memory(ctx_dft), id, -1, -1);
common_context_seq_rm(ctx_dft, id, -1, -1);
}
prompt.tokens.clear();
@@ -238,8 +238,14 @@ struct server_slot {
(ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
}
bool need_embd() const {
GGML_ASSERT(task);
return task->need_embd() || (spec && common_speculative_need_embd(spec));
}
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
// also we cannot split if the pooling would require any past tokens
// (MTP supports splitting — uses task->need_embd() not need_embd())
bool can_split() const {
GGML_ASSERT(task);
@@ -511,12 +517,12 @@ struct server_slot {
void copy_state_to(server_slot & other) const {
GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT);
llama_memory_seq_rm(llama_get_memory(ctx_tgt), other.id, -1, -1);
llama_memory_seq_cp(llama_get_memory(ctx_tgt), id, other.id, -1, -1);
common_context_seq_rm(ctx_tgt, other.id, -1, -1);
common_context_seq_cp(ctx_tgt, id, other.id, -1, -1);
if (ctx_dft) {
llama_memory_seq_rm(llama_get_memory(ctx_dft), other.id, -1, -1);
llama_memory_seq_cp(llama_get_memory(ctx_dft), id, other.id, -1, -1);
common_context_seq_rm(ctx_dft, other.id, -1, -1);
common_context_seq_cp(ctx_dft, id, other.id, -1, -1);
}
other.n_decoded = n_decoded;
@@ -775,10 +781,40 @@ private:
}
auto cparams = common_context_params_to_llama(params_dft);
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
params_base.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
if (spec_mtp) {
cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
}
// note: for small models maybe we can set this to the maximum possible draft from all speculative types
// the extra memory for small models is likely negligible?
cparams.n_rs_seq = 0;
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
} else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end()) {
SRV_INF("creating MTP draft context against the target model '%s'\n",
params_base.model.path.c_str());
auto cparams_mtp = common_context_params_to_llama(params_base);
cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
cparams_mtp.n_rs_seq = 0;
ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp));
if (ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create MTP context\n");
return false;
}
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
}
@@ -2194,12 +2230,12 @@ private:
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, n_keep , n_keep + n_discard);
llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
common_context_seq_rm (ctx_tgt, slot.id, n_keep , n_keep + n_discard);
common_context_seq_add(ctx_tgt, slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
if (ctx_dft) {
llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, n_keep , n_keep + n_discard);
llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard);
common_context_seq_rm (ctx_dft.get(), slot.id, n_keep , n_keep + n_discard);
common_context_seq_add(ctx_dft.get(), slot.id, n_keep + n_discard, slot.prompt.tokens.pos_next(), -n_discard);
}
// add generated tokens to cache
@@ -2306,14 +2342,23 @@ private:
slot.n_draft_total += draft.size();
// TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
if (ctx_dft) {
ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, ckpt.pos_max + 1, -1);
if (ctx_dft) {
if (use_ckpt_dft) {
ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
}
common_context_seq_rm(ctx_dft.get(), slot.id, ckpt.pos_max + 1, -1);
}
if (!draft.empty()) {
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
const bool use_ckpt_tgt =
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ||
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && draft.size() > llama_n_rs_seq(ctx_tgt));
const bool use_ckpt_dft =
(ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && draft.size() > llama_n_rs_seq(ctx_dft.get()));
if (use_ckpt_tgt) {
//const int64_t t_start = ggml_time_us();
@@ -2328,6 +2373,10 @@ private:
(float) ckpt.size() / 1024 / 1024,
(float) ckpt.data_dft.size() / 1024 / 1024);
}
if (use_ckpt_dft) {
ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
}
}
}
@@ -2499,12 +2548,12 @@ private:
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
llama_memory_seq_rm (llama_get_memory(ctx_tgt), slot.id, head_p, head_c);
llama_memory_seq_add(llama_get_memory(ctx_tgt), slot.id, head_c, head_c + n_match, kv_shift);
common_context_seq_rm (ctx_tgt, slot.id, head_p, head_c);
common_context_seq_add(ctx_tgt, slot.id, head_c, head_c + n_match, kv_shift);
if (ctx_dft) {
llama_memory_seq_rm (llama_get_memory(ctx_dft.get()), slot.id, head_p, head_c);
llama_memory_seq_add(llama_get_memory(ctx_dft.get()), slot.id, head_c, head_c + n_match, kv_shift);
common_context_seq_rm (ctx_dft.get(), slot.id, head_p, head_c);
common_context_seq_add(ctx_dft.get(), slot.id, head_c, head_c + n_match, kv_shift);
}
for (size_t i = 0; i < n_match; i++) {
@@ -2667,17 +2716,9 @@ private:
SLT_TRC(slot, "cached n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
if (!llama_memory_seq_rm(llama_get_memory(ctx_tgt), slot.id, p0, -1)) {
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
slot.prompt_clear(true);
// there is no common part left
slot.n_prompt_tokens_cache = 0;
} else {
if (ctx_dft && !llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, p0, -1)) {
GGML_ABORT("failed to truncate draft context\n");
}
common_context_seq_rm(ctx_tgt, slot.id, p0, -1);
if (ctx_dft) {
common_context_seq_rm(ctx_dft.get(), slot.id, p0, -1);
}
// If using an alora, there may be uncached tokens that come
@@ -2703,9 +2744,11 @@ private:
// checkpoints are created only if:
// - the model does not support partial sequence removal
// - the model uses SWA (and we are not using `swa_full`)
// - the model supports partial sequence removal but only up to a fixed bound
do_checkpoint = do_checkpoint && (
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
(n_swa > 0));
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ||
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS ||
n_swa > 0);
bool has_mtmd = false;
@@ -2758,12 +2801,14 @@ private:
break;
}
// embedding requires all tokens in the batch to be output
// embedding requires all tokens in the batch to be output;
// MTP also wants logits at every prompt position so the
// streaming hook can mirror t_h_pre_norm into ctx_dft.
common_batch_add(batch,
cur_tok,
slot.prompt.tokens.pos_next(),
{ slot.id },
slot.task->need_embd());
slot.need_embd());
slot.prompt.tokens.push_back(cur_tok);
slot.n_prompt_tokens_processed++;
@@ -2877,7 +2922,7 @@ private:
slot_batched->lora[alora_disabled_id].scale = alora_scale;
}
llama_set_embeddings(ctx_tgt, slot_batched->task->need_embd());
llama_set_embeddings(ctx_tgt, slot_batched->need_embd());
}
if (batch.n_tokens == 0) {
@@ -3140,13 +3185,8 @@ private:
// verify and try to accept the draft
{
const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
// only save the sampler sampler state if we use checkpoints
common_sampler_ptr smpl_save;
if (use_ckpt_tgt) {
smpl_save.reset(common_sampler_clone(slot.smpl.get()));
}
// save the sampler sampler state in case we need to restore it
common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft);
@@ -3154,8 +3194,14 @@ private:
GGML_ASSERT(accepted.size() >= 1);
const uint32_t n_rollback = slot.spec_draft.size() + 1 - accepted.size();
const bool use_ckpt_tgt =
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ||
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt));
// check for partial draft acceptance
if (accepted.size() < slot.spec_draft.size() + 1) {
if (n_rollback > 0) {
if (use_ckpt_tgt) {
if (trace > 0) {
SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size());
@@ -3171,13 +3217,13 @@ private:
{
ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, ckpt.pos_max + 1, -1);
common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1);
}
if (slot.ctx_dft) {
ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, ckpt.pos_max + 1, -1);
common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1);
}
slot.prompt.tokens.keep_first(ckpt.n_tokens);
@@ -3200,7 +3246,6 @@ private:
const auto ids = std::move(slot.spec_draft);
slot.n_decoded += ids.size();
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
// update how many tokens out of those tested were accepted
@@ -3213,9 +3258,9 @@ private:
slot.sampled = ids.back(); // last accepted token
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
llama_memory_seq_rm(llama_get_memory(slot.ctx_tgt), slot.id, slot.prompt.tokens.pos_next(), -1);
common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1);
if (slot.ctx_dft) {
llama_memory_seq_rm(llama_get_memory(slot.ctx_dft), slot.id, slot.prompt.tokens.pos_next(), -1);
common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1);
}
for (size_t i = 0; i < ids.size(); ++i) {
@@ -3227,6 +3272,8 @@ private:
// TODO: set result.probs
slot.n_decoded += 1;
if (!process_token(result, slot)) {
slot.print_timings();
send_final_response(slot);