spec : update CLI arguments for better consistency (#22964)

* spec : update CLI arguments for better consistency

* cont : fix CLI arg message
This commit is contained in:
Georgi Gerganov
2026-05-13 09:15:39 +03:00
committed by GitHub
parent bcfe63fc53
commit 634275fbbb
6 changed files with 50 additions and 46 deletions
+2 -2
View File
@@ -2223,7 +2223,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
if (llama_supports_rpc()) {
add_opt(common_arg(
{"--rpc"}, "SERVERS",
"comma separated list of RPC servers (host:port)",
"comma-separated list of RPC servers (host:port)",
[](common_params & params, const std::string & value) {
add_rpc_devices(value);
GGML_UNUSED(params);
@@ -3555,7 +3555,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
add_opt(common_arg(
{"--spec-type"}, common_speculative_all_types_str(),
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
string_format("comma-separated list of types of speculative decoding to use (default: %s)\n",
common_speculative_type_name_str(params.speculative.types).c_str()),
[](common_params & params, const std::string & value) {
const auto enabled_types = string_split<std::string>(value, ',');
+4 -3
View File
@@ -157,9 +157,9 @@ enum common_params_sampling_config : uint64_t {
enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
@@ -342,6 +342,7 @@ struct common_params_speculative_ngram_cache {
struct common_params_speculative {
std::vector<enum common_speculative_type> types = { COMMON_SPECULATIVE_TYPE_NONE };
// used by Simple, MTP, Eagle3, etc. - all methods that require some kind of draft model
common_params_speculative_draft draft;
common_params_speculative_ngram_mod ngram_mod;
+39 -39
View File
@@ -21,8 +21,8 @@
const std::map<std::string, common_speculative_type> common_speculative_type_from_name_map = {
{"none", COMMON_SPECULATIVE_TYPE_NONE},
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
{"draft-simple", COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE},
{"draft-eagle3", COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3},
{"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
{"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -145,15 +145,15 @@ struct common_speculative_impl {
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0;
};
struct common_speculative_state_draft : public common_speculative_impl {
struct common_speculative_impl_draft_simple : public common_speculative_impl {
common_params_speculative_draft params;
llama_batch batch;
std::vector<common_sampler_ptr> smpls;
common_speculative_state_draft(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT, n_seq)
common_speculative_impl_draft_simple(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, n_seq)
, params(params.draft)
{
auto * ctx_dft = this->params.ctx_dft;
@@ -206,7 +206,7 @@ struct common_speculative_state_draft : public common_speculative_impl {
}
}
~common_speculative_state_draft() override {
~common_speculative_impl_draft_simple() override {
llama_batch_free(batch);
}
@@ -340,11 +340,11 @@ struct common_speculative_state_draft : public common_speculative_impl {
}
};
struct common_speculative_state_eagle3 : public common_speculative_impl {
struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
//common_params_speculative_eagle3 params;
common_speculative_state_eagle3(const common_params_speculative & /*params*/, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_EAGLE3, n_seq) {}
common_speculative_impl_draft_eagle3(const common_params_speculative & /*params*/, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq) {}
void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override {
// noop
@@ -365,13 +365,13 @@ struct common_speculative_state_eagle3 : public common_speculative_impl {
};
// state of self-speculation (simple implementation, not ngram-map)
struct common_speculative_state_ngram_simple : public common_speculative_impl {
struct common_speculative_impl_ngram_simple : public common_speculative_impl {
common_params_speculative_ngram_map params;
// shared across all sequences
common_ngram_simple_config config;
common_speculative_state_ngram_simple(
common_speculative_impl_ngram_simple(
const common_params_speculative & params, uint32_t n_seq,
common_ngram_simple_config config)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq)
@@ -405,13 +405,13 @@ struct common_speculative_state_ngram_simple : public common_speculative_impl {
}
};
struct common_speculative_state_ngram_map_k : public common_speculative_impl {
struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
common_params_speculative_ngram_map params;
// n_seq configs
std::vector<common_ngram_map> config;
common_speculative_state_ngram_map_k(
common_speculative_impl_ngram_map_k(
const common_params_speculative & params,
const common_ngram_map & config,
uint32_t n_seq)
@@ -453,7 +453,7 @@ struct common_speculative_state_ngram_map_k : public common_speculative_impl {
}
};
struct common_speculative_state_ngram_mod : public common_speculative_impl {
struct common_speculative_impl_ngram_mod : public common_speculative_impl {
common_params_speculative_ngram_mod params;
// shared across all sequences
@@ -475,7 +475,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_impl {
std::vector<seq_info> sinfos;
common_speculative_state_ngram_mod(
common_speculative_impl_ngram_mod(
const common_params_speculative & params,
uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, n_seq)
@@ -621,7 +621,7 @@ struct common_speculative_state_ngram_mod : public common_speculative_impl {
}
};
struct common_speculative_state_ngram_cache : public common_speculative_impl {
struct common_speculative_impl_ngram_cache : public common_speculative_impl {
common_params_speculative_ngram_cache params;
uint16_t n_draft;
@@ -639,7 +639,7 @@ struct common_speculative_state_ngram_cache : public common_speculative_impl {
std::vector<seq_info> sinfos;
common_speculative_state_ngram_cache(
common_speculative_impl_ngram_cache(
const common_params_speculative & params,
uint32_t n_seq,
uint16_t n_draft,
@@ -775,7 +775,7 @@ static common_ngram_map get_common_ngram_map(
return common_ngram_map(size_key, size_value, key_only, min_hits);
}
static common_speculative_state_ngram_cache create_state_ngram_cache(
static common_speculative_impl_ngram_cache create_state_ngram_cache(
const common_speculative_config & config,
uint32_t n_seq,
const std::string & path_static,
@@ -786,7 +786,7 @@ static common_speculative_state_ngram_cache create_state_ngram_cache(
bool save_static = false;
bool save_dynamic = false;
common_speculative_state_ngram_cache state(config.params, n_seq, n_draft, path_static, path_dynamic, save_static, save_dynamic);
common_speculative_impl_ngram_cache state(config.params, n_seq, n_draft, path_static, path_dynamic, save_static, save_dynamic);
return state;
}
@@ -818,8 +818,8 @@ const char * common_speculative_all_types_str() {
std::string common_speculative_type_to_str(common_speculative_type type) {
switch (type) {
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple";
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v";
@@ -872,9 +872,9 @@ common_speculative * common_speculative_init(common_params_speculative & params,
{
uint32_t enabled_configs = common_get_enabled_speculative_configs(params.types);
bool has_draft = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT));
bool has_draft_model = !params.draft.mparams.path.empty();
bool has_draft_model_path = !params.draft.mparams.path.empty();
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
// bool has_mtp = false; // TODO: add MTP here
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
@@ -906,22 +906,22 @@ common_speculative * common_speculative_init(common_params_speculative & params,
if (has_ngram_cache) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
}
if (has_draft) {
if (!has_draft_model) {
if (has_draft_simple) {
if (!has_draft_model_path) {
LOG_WRN("%s: draft model is not specified - cannot use 'draft' type\n", __func__);
has_draft = false;
has_draft_simple = false;
}
} else if (has_draft_model) {
} else if (has_draft_model_path) {
LOG_WRN("%s: draft model is specified but 'draft' speculative type is not explicitly enabled - enabling it\n", __func__);
has_draft = true;
has_draft_simple = true;
}
if (has_draft) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params));
if (has_draft_simple) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, params));
}
// TODO: add MTP here
if (has_draft_eagle3) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
}
}
@@ -932,12 +932,12 @@ common_speculative * common_speculative_init(common_params_speculative & params,
switch (config.type) {
case COMMON_SPECULATIVE_TYPE_NONE:
break;
case COMMON_SPECULATIVE_TYPE_DRAFT: {
impls.push_back(std::make_unique<common_speculative_state_draft>(config.params, n_seq));
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: {
impls.push_back(std::make_unique<common_speculative_impl_draft_simple>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_EAGLE3: {
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.params, n_seq));
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: {
impls.push_back(std::make_unique<common_speculative_impl_draft_eagle3>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
@@ -950,7 +950,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
/* .size_ngram = */ ngram_size_key,
/* .size_mgram = */ mgram_size_value
};
auto state = std::make_unique<common_speculative_state_ngram_simple>(
auto state = std::make_unique<common_speculative_impl_ngram_simple>(
/* .params = */ config.params,
/* .n_seq = */ n_seq,
/* .state = */ config_simple
@@ -961,13 +961,13 @@ common_speculative * common_speculative_init(common_params_speculative & params,
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
impls.push_back(
std::make_unique<common_speculative_state_ngram_map_k>(
std::make_unique<common_speculative_impl_ngram_map_k>(
config.params, get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: {
impls.push_back(
std::make_unique<common_speculative_state_ngram_mod>(config.params, n_seq));
std::make_unique<common_speculative_impl_ngram_mod>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
@@ -975,7 +975,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
config, n_seq,
params.ngram_cache.lookup_cache_static,
params.ngram_cache.lookup_cache_dynamic);
impls.push_back(std::make_unique<common_speculative_state_ngram_cache>(state));
impls.push_back(std::make_unique<common_speculative_impl_ngram_cache>(state));
break;
}
default: