CUDA: fuse SSM_CONV + ADD(bias) + SILU (#22478)
This commit is contained in:
@@ -3579,6 +3579,49 @@ struct test_ssm_conv : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_SSM_CONV + GGML_OP_ADD (channel-wise bias, optional) + GGML_OP_UNARY(SILU) (fused operation)
|
||||
struct test_ssm_conv_bias_silu : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne_a;
|
||||
const std::array<int64_t, 4> ne_b;
|
||||
const bool fuse_bias;
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return "SSM_CONV_BIAS_SILU";
|
||||
}
|
||||
|
||||
bool run_whole_graph() override { return true; }
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne_a, ne_b, fuse_bias);
|
||||
}
|
||||
|
||||
test_ssm_conv_bias_silu(ggml_type type, std::array<int64_t, 4> ne_a, std::array<int64_t, 4> ne_b,
|
||||
bool fuse_bias)
|
||||
: type(type), ne_a(ne_a), ne_b(ne_b), fuse_bias(fuse_bias) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
|
||||
ggml_set_name(a, "a");
|
||||
ggml_set_name(b, "b");
|
||||
|
||||
ggml_tensor * out = ggml_ssm_conv(ctx, a, b);
|
||||
|
||||
if (fuse_bias) {
|
||||
ggml_tensor * bias = ggml_new_tensor_1d(ctx, type, out->ne[0]);
|
||||
ggml_set_name(bias, "bias");
|
||||
out = ggml_add(ctx, out, bias);
|
||||
}
|
||||
|
||||
out = ggml_silu(ctx, out);
|
||||
|
||||
ggml_set_name(out, "out");
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_SSM_SCAN
|
||||
struct test_ssm_scan : public test_case {
|
||||
const ggml_type type;
|
||||
@@ -7977,6 +8020,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
}
|
||||
|
||||
// fused ssm_conv + (optional) bias_add + silu. The bias-only graph (no silu) is intentionally
|
||||
// not tested since there's no fusion for that pattern in ggml_cuda_can_fuse.
|
||||
for (int64_t d_conv : {3, 4, 9}) {
|
||||
for (int64_t d_inner : {1024, 1536, 2048}) {
|
||||
for (bool fuse_bias : {false, true}) {
|
||||
// short token path (n_t <= 32)
|
||||
test_cases.emplace_back(new test_ssm_conv_bias_silu(
|
||||
GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias));
|
||||
test_cases.emplace_back(new test_ssm_conv_bias_silu(
|
||||
GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias));
|
||||
test_cases.emplace_back(new test_ssm_conv_bias_silu(
|
||||
GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}, fuse_bias));
|
||||
// long token path (n_t > 32)
|
||||
test_cases.emplace_back(new test_ssm_conv_bias_silu(
|
||||
GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias));
|
||||
test_cases.emplace_back(new test_ssm_conv_bias_silu(
|
||||
GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}, fuse_bias));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
|
||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
|
||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4)); // Falcon-H1
|
||||
@@ -8993,6 +9057,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
// Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate
|
||||
test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1}, true)); // prefill
|
||||
test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1}, true)); // generate
|
||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill
|
||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate
|
||||
|
||||
|
||||
Reference in New Issue
Block a user