Files
llama.cpp/src/llama-memory-hybrid.h
T
Aman Gupta 255582687b llama + spec: MTP Support (#22673)
* spec: support MTP

* fix batch size

* rename files

* cont : simplify (#7)

* MTP: clean-up (#9)

* MTP: clean-up

* review: use llama_context_type instead of llama_graph_type

* review: remove llama_model_has_mtp

* review: fix convert issues

* convert: fix pycheck

* review: formatting

* use `mtp-` for identifying mtp models

* convert: fix mtp conversion

* mtp -> draft-mtp

* remove unused llama_arch

* add need_embd in speculative

* llama: allow partial seq_rm for GDN models for speculative decoding

Currently speculative checkpoint needs to restart from a checkpoint
after some draft tokens are not accepted, this leads to some wastage in
running the target again. This PR adds the ability to rollback upto
`draft_max` by storing the GDN intermediates.

* fix pending state

* vulkan: add GDN partial rollback

* meta: extend check to axis 1

* metal: add GDN partial rollback

Extend the gated delta net kernel to store intermediate states for
partial rollback support on the Metal backend.

- Add K (snapshot slot count) as a function constant
- Read input state from slot 0 of the 3D state tensor
- Write intermediate states to different slots during token loop
- For K=1, maintain backward-compatible single-slot behavior

Ref: https://github.com/ggml-org/llama.cpp/commit/8c05923630110223669f069af2000e9cf10c02bc

Assisted-by: llama.cpp:local pi

* delta_net_base: use ggml_pad instead of new_tensor

* review: add need_rs_seq

* review: rename part_bounded to n_rs

* review: deslop comments

* review: rename, add asserts

* server : adjust checkpoint logic (#11)

* server : adjust checkpoint logic

* cont : rm asserts

* server-context: fix early exit

* spec : fix compatibility with n-gram and add TODOs (#13)

* metal : cleanup

* llama : fix faulty bitwise check in recurrent memory

* server : disable RS-based MTP in combination with other spec types

* spec : add TODOs

* cont : fix comment

* cont : update comment

* common : fix logic for ngram + mtp compat

* llama-memory: enable checkpointing with partial rollback

* cont: add test-case for loading into a dirty ctx

* llama-memory-recurrent: clear rs_idx in clear

* download: fix mtp path

* llama-arch: fix enorm op

* docs: update docs

* conversion: fix type annotations

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-05-16 20:06:23 +08:00

141 lines
4.3 KiB
C++

#pragma once
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache.h"
#include "llama-memory.h"
#include "llama-memory-recurrent.h"
#include <memory>
#include <vector>
//
// llama_memory_hybrid
//
// utilizes instances of llama_memory_recurrent and llama_kv_cache to
// support models where each layer may be either attention-based or recurrent
class llama_memory_hybrid : public llama_memory_i {
public:
llama_memory_hybrid(
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
uint32_t n_rs_seq,
bool offload,
bool unified,
/* layer filters */
const layer_filter_cb & filter_attn = nullptr,
const layer_filter_cb & filter_recr = nullptr);
~llama_memory_hybrid() = default;
//
// llama_memory_i
//
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_context_ptr init_full() override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_memory_hybrid specific API
//
llama_kv_cache * get_mem_attn() const;
llama_memory_recurrent * get_mem_recr() const;
private:
const llama_hparams & hparams;
const std::unique_ptr<llama_kv_cache> mem_attn;
const std::unique_ptr<llama_memory_recurrent> mem_recr;
};
class llama_memory_hybrid_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
// init failure
explicit llama_memory_hybrid_context(llama_memory_status status);
// init full
explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
// init update
explicit llama_memory_hybrid_context(
llama_memory_hybrid * mem,
llama_context * lctx,
bool optimize);
// init success
llama_memory_hybrid_context(
llama_memory_hybrid * mem,
slot_info_vec_t sinfos_attn,
std::vector<llama_ubatch> ubatches);
~llama_memory_hybrid_context() = default;
bool next() override;
bool apply() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_memory_hybrid_context
//
const llama_kv_cache_context * get_attn() const;
const llama_memory_recurrent_context * get_recr() const;
private:
// the index of the next ubatch to process
size_t i_next = 0;
std::vector<llama_ubatch> ubatches;
const llama_memory_context_ptr ctx_attn;
const llama_memory_context_ptr ctx_recr;
const llama_memory_status status;
};