255582687b
* spec: support MTP
* fix batch size
* rename files
* cont : simplify (#7)
* MTP: clean-up (#9)
* MTP: clean-up
* review: use llama_context_type instead of llama_graph_type
* review: remove llama_model_has_mtp
* review: fix convert issues
* convert: fix pycheck
* review: formatting
* use `mtp-` for identifying mtp models
* convert: fix mtp conversion
* mtp -> draft-mtp
* remove unused llama_arch
* add need_embd in speculative
* llama: allow partial seq_rm for GDN models for speculative decoding

Currently speculative checkpoint needs to restart from a checkpoint after some draft tokens are not accepted, this leads to some wastage in running the target again. This PR adds the ability to rollback upto `draft_max` by storing the GDN intermediates.

* fix pending state
* vulkan: add GDN partial rollback
* meta: extend check to axis 1
* metal: add GDN partial rollback

Extend the gated delta net kernel to store intermediate states for partial rollback support on the Metal backend.

- Add K (snapshot slot count) as a function constant
- Read input state from slot 0 of the 3D state tensor
- Write intermediate states to different slots during token loop
- For K=1, maintain backward-compatible single-slot behavior

Ref: https://github.com/ggml-org/llama.cpp/commit/8c05923630110223669f069af2000e9cf10c02bc

Assisted-by: llama.cpp:local pi

* delta_net_base: use ggml_pad instead of new_tensor
* review: add need_rs_seq
* review: rename part_bounded to n_rs
* review: deslop comments
* review: rename, add asserts
* server : adjust checkpoint logic (#11)
* server : adjust checkpoint logic
* cont : rm asserts
* server-context: fix early exit
* spec : fix compatibility with n-gram and add TODOs (#13)
* metal : cleanup
* llama : fix faulty bitwise check in recurrent memory
* server : disable RS-based MTP in combination with other spec types
* spec : add TODOs
* cont : fix comment
* cont : update comment
* common : fix logic for ngram + mtp compat
* llama-memory: enable checkpointing with partial rollback
* cont: add test-case for loading into a dirty ctx
* llama-memory-recurrent: clear rs_idx in clear
* download: fix mtp path
* llama-arch: fix enorm op
* docs: update docs
* conversion: fix type annotations

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
73 lines
2.9 KiB
C++
#pragma once

#include "llama.h"
#include "common.h"

struct common_speculative;

// comma-separated list of the provided types
std::string common_speculative_type_name_str(const std::vector<enum common_speculative_type> & types);

// comma-separated list of all types
const char * common_speculative_all_types_str();

// parse user-provided types
std::vector<enum common_speculative_type> common_speculative_types_from_names(const std::vector<std::string> & names);

// convert string to type
enum common_speculative_type common_speculative_type_from_name(const std::string & name);

// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);

common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq);

void common_speculative_free(common_speculative * spec);

struct common_speculative_draft_params {
    // this flag is used to chain the drafts through all the available implementations:
    // after the first successful draft from an implementation, we set it
    // to false to prevent further drafts for that sequence
    // at the end of the draft() call, all drafting flags are reset to false
    bool drafting = false;

    // overrides individual configurations (-1 disabled)
    // can be used to constrain the max draft based on the remaining context size
    int32_t n_max = -1;

    llama_pos n_past;
    llama_token id_last;

    // TODO: remove in the future by keeping track of the prompt from the _begin() call and the consecutive accept calls
    const llama_tokens * prompt;

    // the generated draft from the last _draft() call
    llama_tokens * result;
};

common_speculative_draft_params & common_speculative_get_draft_params(common_speculative * spec, llama_seq_id seq_id);

// optionally call once at the beginning of a new generation
void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt);

// process the batch and update the internal state of the speculative context
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);

// true if any implementation requires target embeddings to be extracted
bool common_speculative_need_embd(common_speculative * spec);

// generate drafts for the sequences specified with `common_speculative_get_draft_params`
void common_speculative_draft(common_speculative * spec);

// informs the speculative context that n_accepted tokens of sequence seq_id were accepted by the target model
void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted);

// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);

struct common_speculative_deleter {
    void operator()(common_speculative * s) { common_speculative_free(s); }
};

typedef std::unique_ptr<common_speculative, common_speculative_deleter> common_speculative_ptr;
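The name/type helpers declared above (`common_speculative_type_name_str`, `common_speculative_types_from_names`, `common_speculative_type_from_name`, `common_speculative_type_to_str`) convert between `common_speculative_type` values and user-facing names. Their implementation is not part of this header, so the following is only a minimal, table-driven sketch of the same four roles. Everything in it is an assumption: `demo_spec_type` and its three members, the spellings `none`/`draft`/`ngram`, and the choice to throw `std::invalid_argument` on an unknown name are hypothetical, not the real enum or error handling.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for common_speculative_type; the real enum's
// members and their string spellings are not shown in this header.
enum demo_spec_type {
    DEMO_SPEC_TYPE_NONE,
    DEMO_SPEC_TYPE_DRAFT,
    DEMO_SPEC_TYPE_NGRAM,
};

// One table drives both directions of the conversion, so they cannot drift apart.
struct demo_type_name {
    demo_spec_type type;
    const char *   name;
};

static const demo_type_name k_demo_type_names[] = {
    { DEMO_SPEC_TYPE_NONE,  "none"  },
    { DEMO_SPEC_TYPE_DRAFT, "draft" },
    { DEMO_SPEC_TYPE_NGRAM, "ngram" },
};

// type -> name (the role of common_speculative_type_to_str)
std::string demo_type_to_str(demo_spec_type type) {
    for (const auto & e : k_demo_type_names) {
        if (e.type == type) {
            return e.name;
        }
    }
    return "unknown";
}

// name -> type (the role of common_speculative_type_from_name);
// assumption: an unknown name is reported by throwing
demo_spec_type demo_type_from_name(const std::string & name) {
    for (const auto & e : k_demo_type_names) {
        if (name == e.name) {
            return e.type;
        }
    }
    throw std::invalid_argument("unknown speculative type: " + name);
}

// names -> types (the role of common_speculative_types_from_names)
std::vector<demo_spec_type> demo_types_from_names(const std::vector<std::string> & names) {
    std::vector<demo_spec_type> types;
    types.reserve(names.size());
    for (const auto & name : names) {
        types.push_back(demo_type_from_name(name));
    }
    return types;
}

// types -> "a,b,c" (the role of common_speculative_type_name_str)
std::string demo_type_name_str(const std::vector<demo_spec_type> & types) {
    std::string out;
    for (std::size_t i = 0; i < types.size(); ++i) {
        if (i > 0) {
            out += ',';
        }
        out += demo_type_to_str(types[i]);
    }
    return out;
}
```

Keeping a single table means a new type is added in one place; a list of all types, in the spirit of `common_speculative_all_types_str`, could be produced by joining every table entry the same way.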
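The comment on `common_speculative_draft_params::drafting` describes a first-success-wins chain across the available drafting implementations. The sketch below models only that protocol, in isolation. It assumes a hypothetical `demo_drafter` callback per implementation and a simplified `demo_draft_params` that owns its result instead of pointing to caller storage; it is not the actual `common_speculative_draft` code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

using demo_tokens = std::vector<int32_t>;

// Hypothetical, simplified mirror of common_speculative_draft_params.
// The result is stored by value here; the real struct holds a pointer.
struct demo_draft_params {
    bool        drafting = false; // set by the caller to request a draft for this sequence
    int32_t     n_max    = -1;    // -1: no override
    demo_tokens result;
};

// Hypothetical drafter interface: writes up to n_max tokens into `out`
// (its own default when n_max < 0) and returns false if it has nothing to propose.
using demo_drafter = std::function<bool(int32_t n_max, demo_tokens & out)>;

// Implementations are tried in order; the first one that yields a non-empty
// draft for a sequence clears its `drafting` flag, so later ones skip it.
void demo_draft_all(const std::vector<demo_drafter> & impls, std::vector<demo_draft_params> & seqs) {
    // a requested sequence starts with an empty draft
    for (auto & p : seqs) {
        if (p.drafting) {
            p.result.clear();
        }
    }
    for (const auto & impl : impls) {
        for (auto & p : seqs) {
            if (!p.drafting) {
                continue; // not requested, or already drafted by an earlier implementation
            }
            demo_tokens out;
            if (impl(p.n_max, out) && !out.empty()) {
                assert(p.n_max < 0 || out.size() <= static_cast<std::size_t>(p.n_max));
                p.result   = std::move(out);
                p.drafting = false;
            }
        }
    }
    // as the header documents: all drafting flags are reset at the end of the call
    for (auto & p : seqs) {
        p.drafting = false;
    }
}
```

Because the flag is cleared on the first non-empty draft, the order of `impls` acts as a priority list in this sketch: an implementation that often finds nothing can sit ahead of another without blocking it.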
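The `n_max` field can be used to cap the draft length when the context is nearly full. One way a caller might compute that cap is sketched below. The helper name `demo_n_max_for_ctx` is hypothetical, and reserving one slot for `id_last` is an assumption about how the verification batch is laid out; this header does not specify it.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical caller-side helper computing a value for the n_max field.
// Assumes the verification batch holds id_last at position n_past followed by
// the draft, so n_past + 1 + n_draft must not exceed n_ctx.
int32_t demo_n_max_for_ctx(int32_t n_draft_max, int32_t n_ctx, int32_t n_past) {
    assert(n_draft_max >= 0);
    const int32_t remaining = std::max<int32_t>(0, n_ctx - n_past - 1);
    return std::min(n_draft_max, remaining);
}
```

For example, with `n_ctx = 4096` and `n_past = 4090`, at most 5 draft tokens fit after `id_last`, so a configured maximum of 16 is reduced to 5.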
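`common_speculative_accept` receives only the count `n_accepted`; how that count is obtained is up to the caller. A common acceptance rule in speculative decoding is greedy verification: accept the longest prefix of the draft that matches the tokens the target model itself chose at those positions. The helper below is a hypothetical sketch of that rule; the header does not say which rule llama.cpp applies.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Greedy verification sketch. target_ids[i] is the token the target model
// picked at the position of draft[i]; it normally has one extra entry for the
// token that follows the draft.
uint16_t demo_count_accepted(const std::vector<int32_t> & draft, const std::vector<int32_t> & target_ids) {
    assert(target_ids.size() >= draft.size());
    assert(draft.size() <= UINT16_MAX); // n_accepted is a uint16_t in the header
    std::size_t n = 0;
    while (n < draft.size() && draft[n] == target_ids[n]) {
        ++n; // stop at the first mismatch; everything after it is rejected
    }
    return static_cast<uint16_t>(n);
}
```

The result is what would be passed as `n_accepted` to `common_speculative_accept(spec, seq_id, n_accepted)`. The rejected tail of the draft still has to be discarded from the target's state, which is what the partial `seq_rm` rollback in the commit message above is concerned with.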
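`common_speculative_ptr` pairs `std::unique_ptr` with `common_speculative_deleter` so that `common_speculative_free` runs automatically when the owner goes out of scope. The same RAII pattern is shown below on a self-contained, hypothetical `demo_handle` type with a live-handle counter, so the automatic release can be observed without linking against llama.cpp.

```cpp
#include <cassert>
#include <memory>

// Hypothetical opaque handle with C-style init/free functions, standing in for
// common_speculative, common_speculative_init and common_speculative_free.
struct demo_handle {
    int id;
};

static int g_demo_live = 0; // number of handles created but not yet freed

demo_handle * demo_handle_init(int id) {
    demo_handle * h = new demo_handle{id};
    ++g_demo_live;
    return h;
}

void demo_handle_free(demo_handle * h) {
    if (h != nullptr) {
        --g_demo_live;
        delete h;
    }
}

// Same shape as common_speculative_deleter: a stateless functor that forwards
// to the free function, so unique_ptr releases the handle on scope exit.
struct demo_handle_deleter {
    void operator()(demo_handle * h) const { demo_handle_free(h); }
};

using demo_handle_ptr = std::unique_ptr<demo_handle, demo_handle_deleter>;
```

With the real header, the equivalent would be `common_speculative_ptr spec(common_speculative_init(params, n_seq));`, after which `spec.get()` can be passed to the other `common_speculative_*` functions and no explicit `common_speculative_free` call is needed. `std::unique_ptr` never invokes its deleter on a null pointer, so the null check in the free function is only defensive.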