Compare commits

..

48 Commits

Author SHA1 Message Date
ProgenyAlpha deee23863b vulkan: add GATED_DELTA_NET op support (#20334)
* vulkan: add GATED_DELTA_NET op support

Implements the fused gated delta net recurrence as a Vulkan compute
shader with full support for scalar gate, KDA vector gate, GQA
broadcast, multi-token sequences, and permuted (non-contiguous) q/k
inputs. Specialization constants select head size (32/64/128) and
KDA mode at pipeline creation time.

Passes all 13 test-backend-ops cases on AMD Radeon 890M (RADV GFX1150).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* vulkan: optimize GATED_DELTA_NET shader (Phase 1)

- vec4 dot products on all inner loops (dp4 hardware intrinsic)
- Cache exp(g) in shared memory for KDA path, eliminating ~32K
  redundant global reads and ~16K redundant exp() calls per token
- vec4 fused decay + rank-1 update (3 vec4 ops vs 12 scalar ops)
- Add perf benchmark cases for GATED_DELTA_NET to test-backend-ops

KDA TG: +5.4% throughput. Non-KDA: no regressions.
13/13 test-backend-ops passing on AMD Radeon 890M (RADV GFX1150).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* vulkan: address review feedback for GATED_DELTA_NET

Pipeline array refactor [3][2], A_TYPE/D_TYPE/FLOAT_TYPE shader macros,
scale in push constants, supports_op fix, dispatch restructuring.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* vulkan: use FLOAT_TYPE for buffer/shared declarations, align formatting

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* vulkan: add explicit FLOAT_TYPE casts for buffer loads

Wrap data_q, data_k, and data_g buffer reads with FLOAT_TYPE() casts
to ensure correct behavior across all Vulkan configurations.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* vulkan: fix Q/K broadcast for interleaved head layout

Adapt to the interleaved broadcast convention from #20340:
head_id / rq1 → head_id % neq1

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Progeny Alpha <ProgenyAlpha@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-12 11:32:04 +01:00
Sigbjørn Skjæret c3e3f9e533 convert : better mtp check and fix return [no ci] (#20419) 2026-03-12 10:04:20 +01:00
ProgenyAlpha 40c550d4f6 vulkan: fix SSM_CONV PP scaling with large ubatch sizes (#20379)
* vulkan: optimize SSM_CONV workgroup dispatch for large ubatch

Tile tokens into 2D workgroups (32x16) to reduce workgroup launch
overhead at large ubatch sizes. Add vec4 fast path for nc=4 (common
d_conv size). Fixes PP performance degradation with ubatch > 512.

Ref: ggml-org/llama.cpp#18725

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* vulkan: remove unused shared memory declaration in SSM_CONV

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Progeny Alpha <ProgenyAlpha@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-12 10:03:18 +01:00
Pascal de190154c8 New conversations now auto-select the first loaded model (#20403)
* webui: auto-select first loaded model for new conversations in router mode

* chore: update webui build output
2026-03-12 09:07:05 +01:00
Masashi Yoshimura 05039967da ggml-virtgpu: Fix some build commands (#20341) 2026-03-12 15:47:45 +08:00
Georgi Gerganov e4cff0956b metal : avoid divisions in bin kernel (#20426)
* metal : avoid modulus in bin kernel when not broadcasting

* metal : fix capture_started flag
2026-03-12 09:42:40 +02:00
Masato Nakasaka 4cc6eb158c ci: Setup self-hosted CI for Intel Linux Vulkan backend (#20154) 2026-03-12 06:43:22 +01:00
Jeff Bolz 246ffc4b05 vulkan: fix l2_norm epsilon handling (#20350) 2026-03-12 06:39:41 +01:00
Jeff Bolz aa429cf507 vulkan: fix OOB check in flash_attn_mask_opt (#20296) 2026-03-12 06:35:49 +01:00
Masato Nakasaka 5866e3bbc8 vulkan: Fix ErrorOutOfHostMemory on Intel GPU when loading large models with --no-mmap (#20059)
* Changed to reuse command buffers to fix crashing on Intel GPU

* Removed unused parameter

* Fixed compile error and minor mistake

* Fix logging

* Changing to use usage flag per command buffer

* fixed style

* added buffer reset

* Removed cmd_buffer_idx for reuse consistency

* Fixed style
2026-03-12 06:30:16 +01:00
lhez 0516e04bf9 opencl: use larger workgroup size for get_rows (#20316) 2026-03-11 22:03:27 -07:00
shaofeiqi 3d9ab225e7 opencl: add cumsum op (#18981)
* OpenCL: add CUMSUM op support

* remove unused argument

* opencl: refactor cumsum

* opencl: refactor

* opencl: refactor tmp buffer

* opencl: adjust max number of subgroups

* opencl: fix whitespace

* opencl: fix global size when cumsum the tmp buffer

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-03-11 22:03:07 -07:00
uvos d63aa398de hip: compile debug builds with -O2 on hip to avoid a compiler bug (#20392) 2026-03-12 10:37:10 +08:00
Mishusha a8304b4d27 common/parser: add GigaChatV3/3.1 models support (#19931)
Co-authored-by: Mishusha <pmv26021975@gmail.com>
2026-03-12 01:22:25 +01:00
DAN™ fdb17643d3 model : add support for Phi4ForCausalLMV (#20168)
* Add support for Phi4ForCausalLMV.

* Fix Phi-4 vision parity (correcting SigLIP2 patch-kernel export layout) and matching HF NaFlex resize behavior in mtmd.

* Rename contants + fix tokenizer label

* Clean-ups.

* Fix GGUF export.

* Set tokenizer.ggml.pre explicitly.

* Default vocab name rather than forcing it.

* Clean-ups.

* Fix indent.

* Fix subscriptable error.

* remov overcomplicated code path

* Clean-ups.

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2026-03-12 00:25:54 +01:00
Richard Davison 1eea6a2968 graph : add optional scale parameter to build_lora_mm [no ci] (#20427) 2026-03-12 00:22:49 +01:00
ddh0 4a748b8f15 common : fix --n-cpu-moe, --cpu-moe for models with fused gate + up (#20416) 2026-03-12 00:13:28 +01:00
Masashi Yoshimura f2ab047f27 ggml-webgpu: Add supports for GGML_OP_REPEAT (#20230)
* Add GGML_OP_REPEAT to webgpu backend.

* Add i16 support for GGML_OP_REPEAT.
2026-03-11 14:40:36 -07:00
Georgi Gerganov d28961d81e llama : enable chunked fused GDN path (#20340)
* llama : enable chunked fused GDN path

* models : avoid Q and K repeats when using fused GDA

* cont : fix comment

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* cont : fix the fix

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* cont : fix

* metal : add GDN kernel (#20361)

* metal : add Metal backend for GGML_OP_GATED_DELTA_NET

Add a fused Metal kernel for the gated delta net recurrence op
(#19504), enabling GPU-accelerated inference for DeltaNet-based
models (Qwen3.5, etc.) on Apple Silicon.

Supports both GDA (scalar gate) and KDA (per-row gate) modes
with head_size 64 and 128. Unsupported configurations (head_size
32, non-contiguous tensors) gracefully fall back to CPU.

Performance: Qwen3.5-0.8B Q4_K_M on M4 Max
  tg128: 170 -> 213 t/s (+25%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : validate contiguity of all input tensors in supports_op

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : add algorithm equivalence comment for GDA decay path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* cont : unslop + optimize

* cont : clean-up

---------

Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* CUDA: AR gated delta net improvements (#20391)

* Add FastDiv to gated_delta_net_cuda

* Shard columns across warps

This reduces register pressure (avoids spill for S_v = 128) and gives
the warp-scheduler more CTAs to schedule (thus hiding data-access
latencies).

* Remove unneded include in gated_delta_net.cu

* Improve comments

* Apply code-formating

* Make sharding HIP-compatible

1. Use ggml_cuda_get_physical_warp_size() to determine warp size flexibly
2. Add test with partial warp to test sum reduction on CUDA

* Remove fastdiv_s64, as we can treat neqk1 and rq3 as uint32_t

* Rename variables

* Enable GDN also for prefill, move TODO for chunked_GDN

* Actually remove the TODO from 2068908975

* Get warp size at runtime

warp_size is not known at compile time in hip host code.

* Don't expose ggml_cuda_get_physical_warp_size on host

---------

Co-authored-by: uvos <devnull@uvos.xyz>

* llama : refactor llm_build_delta_net_base API

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Oliver Simons <osimons@nvidia.com>
Co-authored-by: uvos <devnull@uvos.xyz>
2026-03-11 22:46:40 +02:00
Sigbjørn Skjæret f90bd1dd84 llama : whitespace cleanup (#20422) 2026-03-11 21:18:29 +01:00
Richard Davison 5eae9cb1d9 ggml : add NVFP4 quantization type support (#19769)
* WIP: add NVFP4 quantization support

* tests

* improve NVFP4 dot product implementation performance and fix bad super call

* typo

* Use nvfp4 kvalues

* vulkan : fix NVFP4 shader compilation by including kvalues_mxfp4 lookup table

* vulcal and perf fixes

* wip

* Fix metal

* fix vulcan

* Rename threshold & fix wrong scale

* Fix MOE

* Shelf backend implementations (CUDA, Metal, Vulkan, arch-specific SIMD)

Remove NVFP4 support from GPU backends and architecture-specific
optimized dot products. These should be added in separate PRs so
backend specialists can review them independently.

Reverted files:
- ggml-cuda: common.cuh, convert.cu, mmq.cu/cuh, mmvq.cu, vecdotq.cuh,
  quantize.cu/cuh, mma.cuh, ggml-cuda.cu, fattn-tile.cuh
- ggml-metal: ggml-metal.metal, ggml-metal-device.cpp, ggml-metal-impl.h,
  ggml-metal-ops.cpp
- ggml-vulkan: ggml-vulkan.cpp, all vulkan-shaders/*
- ggml-cpu arch: arm/quants.c, x86/quants.c, powerpc/quants.c, s390/quants.c

Core NVFP4 support (type definition, CPU fallback dot product,
quantization, dequantization, conversion) is retained.

* Fix arch-fallback.h: add NVFP4 generic fallback for all platforms

After shelving backend-specific SIMD implementations, the generic
CPU dot product needs to be aliased on ARM, x86, PowerPC, and s390
platforms that previously relied on arch-specific versions.

* quantize: add NVFP4 as a quantization type option

* Fix ggml_fp32_to_ue4m3: handle subnormal values

Previously, values with ue4m3_exp <= 0 were clamped to 0, causing
all small scales to underflow. This made NVFP4 quantization via
llama-quantize produce garbage (PPL = 5.8M) since typical transformer
weights have amax/6.0 in the range 0.001-0.01, which falls in the
UE4M3 subnormal range.

Now subnormals are properly encoded as man * 2^-9 (exp=0, man=1..7),
matching the decode path in ggml_ue4m3_to_fp32.

Result: NVFP4 requantization now produces PPL = 15.25 (vs F16 = 14.33),
comparable to Q4_1 (PPL = 15.81) at slightly lower BPW (4.70 vs 5.15).

* Restore ARM NEON NVFP4 dot product implementation

Restores the optimized ggml_vec_dot_nvfp4_q8_0 for ARM NEON using
vqtbl1q_s8 lookup and ggml_vdotq_s32 dot products.

tg128 performance: 4.37 t/s (generic) -> 13.66 t/s (NEON) = 3.1x speedup

* Optimize ARM NEON NVFP4 dot product: LUT + vpaddq + vfmaq

- Add ue4m3_scale_lut[128] to ggml-common.h replacing branch-heavy
  ggml_ue4m3_to_fp32() in the hot loop
- Use vpaddq_s32 for pairwise int32 reduction instead of vaddvq_s32
- Accumulate with vfmaq_f32 into float32x4_t vector accumulators

tg128: 8.1 -> 31.0 t/s (3.8x speedup, 77% of Q4_1 speed)

* ARM NEON NVFP4: rearrange q8 to match nibble layout

Alternative approach: rearrange q8 data to match the NVFP4 lo/hi
nibble layout instead of rearranging the looked-up NVFP4 values.
Eliminates vcombine_s8(vget_low, vget_low) shuffles.

Performance is equivalent (~18.5 t/s) - the bottleneck is the 2x
block overhead from QK=16 vs QK=32, not the shuffle instructions.

* CPU only backend 64 super-block layout

* cleanup

* Remove unused LUT

* int

* exclude NVFP4 from unsupported ops in metal build

* remove quantization for now

* store scales as native UE4M3, preserve original model bits when possible

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* correct comment

* format

* reduce duplication and cleanup

* Address comments

* move detection to prepare_tensors

* Use math instead of const

* Move

* fix comment

* Shelf quantize tests

* Rebase and move check

* cleanup

* lint

* Update gguf-py/gguf/scripts/gguf_convert_endian.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Use fallback quant config

* Simplify

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* organize

* Refactor

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* add quantize_nvfp4 (required for test_quants.py)

* add quantize_nvfp4 (required for test_quants.py)

* add quantize_nvfp4 (required for test_quants.py)

* fix return type

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-03-11 21:02:54 +01:00
Georgi Gerganov 3ca19b0e9f benches : add nemotron super (#20420) 2026-03-11 21:39:40 +02:00
Daniel Bevenius eaf1d7930c llama : add support for Nemotron 3 Super (#20411)
* llama : add support for Nemotron 3 Super

This commit adds support for the Nemotron 3 Super model (120B.A12B)
enabling this model to be converted to GGUF format and run in llama.cpp.

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Matt Clayton <156335168+mattjcly@users.noreply.github.com>
2026-03-11 19:27:53 +01:00
Georgi Gerganov 76ea1c1c46 metal : fix capture_compute counter logic (#20410) 2026-03-11 18:38:22 +02:00
Aman Gupta bd1ec818e9 compare-llama-bench: check remotes as well (#20406) 2026-03-12 00:14:42 +08:00
Georgi Gerganov b541241104 metal : fix q5_k mul_mv register spill (#20399) 2026-03-11 16:25:27 +02:00
Georgi Gerganov c363256839 metal : add env var to trigger graph capture (#20398) 2026-03-11 16:25:10 +02:00
Neo Zhang ecac98ee53 [SYCL] Update SYCL.md for binary package for Windows (#20401)
* add download binary package

* update prefix
2026-03-11 22:21:22 +08:00
Ruben Ortlam 182acfe5c5 ci: disable coopmat on ubuntu-24-cmake-vulkan job (#20294) 2026-03-11 14:12:29 +01:00
Aldehir Rojas b5fe4559ae common/parser: use nlohmann::ordered_json to preserve parameter order (#20385) 2026-03-11 10:26:51 +01:00
Piotr Wilkin (ilintar) acb7c79069 common/parser: handle reasoning budget (#20297)
* v1

* Finished!

* Handlie cli

* Reasoning sampler

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Less explosive terminology :)

* Add utf-8 case and tests

* common : migrate reasoning budget sampler to common

* cont : clean up

* cont : expose state and allow passing as initial state

* cont : remove unused imports

* cont : update state machine doc string

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Alde Rojas <hello@alde.dev>
2026-03-11 10:26:12 +01:00
uvos 5f91b1d5d5 ggml-cuda: gdn use shared mem for HIP (#20366)
Suggested-by: Aman Gupta <amangupta052@gmail.com>
2026-03-11 13:06:19 +08:00
uvos 9ef7523ee9 cuda/hip: fix loop unrolling in ssm-conv (#20369) 2026-03-11 13:04:32 +08:00
Pascal 00de615345 Fix agentic mcp image single model (#20339)
* webui: fix MCP image attachments dropped during the agentic loop in single-model mode

* chore: update webui build output
2026-03-11 05:31:33 +01:00
Alessandro de Oliveira Faria (A.K.A.CABELO) e1a399992b vendor : update cpp-httplib to 0.37.0 (#20207) 2026-03-11 11:03:53 +08:00
Alessandro de Oliveira Faria (A.K.A.CABELO) 4f2f0a163d vendor : update miniaudio to 0.11.25 (#20209) 2026-03-11 11:01:56 +08:00
Neo Zhang 0cec84f999 fix op rope, add rope_back (#20293) 2026-03-11 09:53:34 +08:00
Neo Zhang b2e1427c9b fix for failed UT case: ACC, L2_NORM, UPSCALE, fused_glu, unary (#20283) 2026-03-11 09:53:05 +08:00
Vinicios Lugli 4d99d45084 model : qwen3vl reranker text support (#20332)
* model : fix qwen3vl reranker support

* Remove CLS_OUT

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-03-10 23:40:14 +01:00
ddh0 10e5b148b0 llama-quant : correct n_attention_wv usage (#20357)
* llama-quant : correct `n_attention_wv` usage

In #19770, I introduced a regression in the way the
`quantize_state_impl` counter values were initialized. I was
incrementing and using `n_attention_wv` in the same loop, when it should
have been fixed by the time we're deciding tensor types in
`llama_tensor_get_type_impl` (for `use_more_bits`).

I never observed a difference in any of [my
tests](https://github.com/ggml-org/llama.cpp/pull/19770#issuecomment-4000424712)
- it was only after @bartowski kindly pointed this out that I realized
it was incorrect. (Thanks!)

* simplify
2026-03-10 21:43:29 +02:00
Georgi Gerganov 90b2731894 ggml : bump RPC version (#20330) 2026-03-10 21:36:57 +02:00
Reese Levine aa2d278a11 ggml webgpu: faster normal quant and some k-quant matrix operations, better shader parameter handling (#20173)
* K quant speedup (#20)

* Basic JIT compilation for mul_mat, get_rows, and scale (#17)

* scale jit working

* preliminary working jit for getrows and mulmat, needs refining

* simplified mul_mat preprocessing switch statement

* get_rows fixes, mul_mat refinement

* formatted + last edits

* removed some extraneous prints

* fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish

* small fix

* some changes, working

* get_rows and mul_mat jit fixed and working

* Update formatting

* formatting

* Add header

---------

Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* Start work on all-encompassing shader library

* refactor argmax, set_rows

* Refactor all but flashattention, mat mul

* no gibberish, all k quants added, merged

* vec memory fix

* q6_k matching metal on my machine, tests passing

* Set tile size for q6_k separately

* Separate out fast shaders

---------

Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com>

* Move towards writeBuffer for params

* Move away from multiple buffers for set_rows errors, remove host buffer for parameter buffers, minor cleanups

* Remove extra file

* Formatting

---------

Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com>
2026-03-10 09:14:27 -07:00
Piotr Wilkin (ilintar) 6c770d16ca Reduce level of content parser warning message to avoid log spam on non-debug verbosity (#20347) 2026-03-10 15:21:51 +01:00
Ray Xu 8d880ac012 examples : fix empty items in json_schema_to_grammar.py [no ci] (#19968)
* Fix logic for retrieving schema items in `json_schema_to_grammar.py`

If `schema['items']` is `{}` and `prefixItems not in schema', as `{}` is Falsy, the original code here will raise an error.

I think if `schema['items']` is `{}`, them items should just be `{}`

* Apply suggestion from @CISC

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Add tests for arrays with empty items

Add two unit tests to `tests/test-json-schema-to-grammar.cpp` that validate handling of arrays when 'items' is an empty schema and when 'prefixItems' is present alongside an empty 'items'. Both tests expect the same generated grammar, ensuring the JSON Schema->grammar conversion treats an empty 'items' schema (and the presence of 'prefixItems') correctly and covering this edge case.

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-03-10 14:38:18 +01:00
a3894281 0f1e9d14cc docs: update CPU backend ops to mark POOL_1D as supported (#20304) 2026-03-10 21:31:24 +08:00
Georgi Gerganov 1274fbee9e models : fix assert in mamba2 (cont) (#20335)
* models : fix assert in mamba2 (cont)

* cont : add n_group mod

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-03-10 15:00:08 +02:00
Georgi Gerganov a7b3dee7a5 server : make 2 checkpoints near the end of the prompt (#20288)
* server : make 2 checkpoints near the end of the prompt

* cont : adjust checkpoints
2026-03-10 14:28:23 +02:00
Sigbjørn Skjæret ec947d2b16 common : fix incorrect uses of stoul (#20313) 2026-03-10 11:40:26 +01:00
127 changed files with 9294 additions and 21102 deletions
+17
View File
@@ -469,6 +469,7 @@ jobs:
cd build
export GGML_VK_VISIBLE_DEVICES=0
export GGML_VK_DISABLE_F16=1
export GGML_VK_DISABLE_COOPMAT=1
# This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 4800
@@ -1726,6 +1727,22 @@ jobs:
vulkaninfo --summary
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
ggml-ci-x64-linux-intel-vulkan:
runs-on: [self-hosted, Linux, X64, Intel]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Test
id: ggml-ci
run: |
vulkaninfo --summary
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
ggml-ci-arm64-cpu-kleidiai:
runs-on: ubuntu-22.04-arm
+72
View File
@@ -0,0 +1,72 @@
# NVIDIA DGX Spark
## System info
```bash
uname --all
Linux spark-17ed 6.11.0-1016-nvidia #16-Ubuntu SMP PREEMPT_DYNAMIC Sun Sep 21 16:52:46 UTC 2025 aarch64 aarch64 aarch64 GNU/Linux
g++ --version
g++ (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
nvidia-smi
Fri Mar 6 11:39:45 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GB10 On | 0000000F:01:00.0 Off | N/A |
| N/A 52C P0 13W / N/A | Not Supported | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
```
## ggml-org/nemotron-3-super-120b-GGUF
Model: https://huggingface.co/ggml-org/nemotron-3-super-120b-GGUF
- `llama-batched-bench`
main: n_kv_max = 303104, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, is_pp_shared = 0, is_tg_separate = 0, n_gpu_layers = 99, n_threads = 20, n_threads_batch = 20
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
| 512 | 32 | 1 | 544 | 1.094 | 468.05 | 1.621 | 19.74 | 2.715 | 200.37 |
| 512 | 32 | 2 | 1088 | 1.463 | 700.16 | 2.437 | 26.26 | 3.900 | 279.01 |
| 512 | 32 | 4 | 2176 | 2.647 | 773.76 | 4.043 | 31.66 | 6.689 | 325.29 |
| 512 | 32 | 8 | 4352 | 5.291 | 774.14 | 6.151 | 41.62 | 11.442 | 380.37 |
| 512 | 32 | 16 | 8704 | 10.603 | 772.62 | 10.385 | 49.30 | 20.987 | 414.72 |
| 512 | 32 | 32 | 17408 | 21.231 | 771.69 | 18.235 | 56.16 | 39.466 | 441.09 |
| 4096 | 32 | 1 | 4128 | 5.340 | 767.05 | 1.616 | 19.81 | 6.956 | 593.47 |
| 4096 | 32 | 2 | 8256 | 10.673 | 767.55 | 2.454 | 26.08 | 13.127 | 628.94 |
| 4096 | 32 | 4 | 16512 | 21.348 | 767.46 | 4.072 | 31.44 | 25.420 | 649.57 |
| 4096 | 32 | 8 | 33024 | 42.714 | 767.15 | 6.277 | 40.78 | 48.991 | 674.08 |
| 4096 | 32 | 16 | 66048 | 85.385 | 767.54 | 10.596 | 48.32 | 95.981 | 688.14 |
| 4096 | 32 | 32 | 132096 | 170.819 | 767.32 | 18.619 | 55.00 | 189.437 | 697.31 |
| 8192 | 32 | 1 | 8224 | 10.690 | 766.32 | 1.619 | 19.76 | 12.310 | 668.10 |
| 8192 | 32 | 2 | 16448 | 21.382 | 766.24 | 2.467 | 25.94 | 23.850 | 689.65 |
| 8192 | 32 | 4 | 32896 | 42.782 | 765.92 | 4.098 | 31.23 | 46.881 | 701.69 |
| 8192 | 32 | 8 | 65792 | 85.582 | 765.77 | 6.368 | 40.20 | 91.951 | 715.52 |
| 8192 | 32 | 16 | 131584 | 171.066 | 766.21 | 10.774 | 47.52 | 181.840 | 723.62 |
| 8192 | 32 | 32 | 263168 | 342.140 | 766.19 | 18.969 | 53.98 | 361.109 | 728.78 |
- `llama-bench`
| model | size | params | backend | n_ubatch | fa | test | t/s |
| ----------------------- | ---------: | ---------: | ---------- | -------: | -: | --------------: | -------------------: |
| nemotron 120B.A12B Q4_K | 65.10 GiB | 120.67 B | CUDA | 2048 | 1 | pp2048 | 768.84 ± 0.90 |
| nemotron 120B.A12B Q4_K | 65.10 GiB | 120.67 B | CUDA | 2048 | 1 | tg32 | 19.94 ± 0.16 |
| nemotron 120B.A12B Q4_K | 65.10 GiB | 120.67 B | CUDA | 2048 | 1 | pp2048 @ d4096 | 764.51 ± 0.50 |
| nemotron 120B.A12B Q4_K | 65.10 GiB | 120.67 B | CUDA | 2048 | 1 | tg32 @ d4096 | 19.95 ± 0.18 |
| nemotron 120B.A12B Q4_K | 65.10 GiB | 120.67 B | CUDA | 2048 | 1 | pp2048 @ d8192 | 759.53 ± 0.71 |
| nemotron 120B.A12B Q4_K | 65.10 GiB | 120.67 B | CUDA | 2048 | 1 | tg32 @ d8192 | 19.83 ± 0.18 |
| nemotron 120B.A12B Q4_K | 65.10 GiB | 120.67 B | CUDA | 2048 | 1 | pp2048 @ d16384 | 747.98 ± 1.58 |
| nemotron 120B.A12B Q4_K | 65.10 GiB | 120.67 B | CUDA | 2048 | 1 | tg32 @ d16384 | 19.84 ± 0.18 |
| nemotron 120B.A12B Q4_K | 65.10 GiB | 120.67 B | CUDA | 2048 | 1 | pp2048 @ d32768 | 724.40 ± 2.70 |
| nemotron 120B.A12B Q4_K | 65.10 GiB | 120.67 B | CUDA | 2048 | 1 | tg32 @ d32768 | 19.45 ± 0.18 |
build: 04a65daab (8268)
+2
View File
@@ -81,6 +81,8 @@ add_library(${TARGET} STATIC
preset.cpp
preset.h
regex-partial.cpp
reasoning-budget.cpp
reasoning-budget.h
regex-partial.h
sampling.cpp
sampling.h
+33 -4
View File
@@ -2427,11 +2427,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
);
}
if (split_arg.size() == 1) {
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoul(split_arg[0]) * 1024*1024);
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoull(split_arg[0]) * 1024*1024);
return;
}
for (size_t i = 0; i < split_arg.size(); i++) {
params.fit_params_target[i] = std::stoul(split_arg[i]) * 1024*1024;
params.fit_params_target[i] = std::stoull(split_arg[i]) * 1024*1024;
}
}
).set_env("LLAMA_ARG_FIT_TARGET"));
@@ -2913,6 +2913,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
auto parsed = json::parse(value);
for (const auto & item : parsed.items()) {
if (item.key() == "enable_thinking") {
LOG_WRN("Setting 'enable_thinking' via --chat-template-kwargs is deprecated. "
"Use --reasoning on / --reasoning off instead.\n");
}
params.default_template_kwargs[item.key()] = item.value().dump();
}
}
@@ -3048,14 +3052,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.reasoning_format = common_reasoning_format_from_name(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK"));
add_opt(common_arg(
{"-rea", "--reasoning"}, "[on|off|auto]",
"Use reasoning/thinking in the chat ('on', 'off', or 'auto', default: 'auto' (detect from template))",
[](common_params & params, const std::string & value) {
if (is_truthy(value)) {
params.enable_reasoning = 1;
params.default_template_kwargs["enable_thinking"] = "true";
} else if (is_falsey(value)) {
params.enable_reasoning = 0;
params.default_template_kwargs["enable_thinking"] = "false";
} else if (is_autoy(value)) {
params.enable_reasoning = -1;
} else {
throw std::invalid_argument(
string_format("error: unknown value for --reasoning: '%s'\n", value.c_str()));
}
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING"));
add_opt(common_arg(
{"--reasoning-budget"}, "N",
"controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)",
"token budget for thinking: -1 for unrestricted, 0 for immediate end, N>0 for token budget (default: -1)",
[](common_params & params, int value) {
if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); }
if (value < -1) { throw std::invalid_argument("invalid value"); }
params.reasoning_budget = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET"));
add_opt(common_arg(
{"--reasoning-budget-message"}, "MESSAGE",
"message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)",
[](common_params & params, const std::string & value) {
params.reasoning_budget_message = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(
+3 -1
View File
@@ -135,7 +135,9 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
if (thinking_forced_open || thinking_forced_closed) {
// Thinking is forced open OR forced closed with enable_thinking=true
// In both cases, expect only the closing tag (opening was in template)
return p.reasoning(p.until(end)) + end;
// However, since we might have incorrectly detected the open/close pattern,
// we admit an optional starting marker
return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end;
}
if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
// Standard tag-based reasoning OR tools-only mode (reasoning appears with tools)
+24 -24
View File
@@ -6,7 +6,7 @@
#include <nlohmann/json.hpp>
using json = nlohmann::ordered_json;
using ordered_json = nlohmann::ordered_json;
static std::string_view trim_trailing_space(std::string_view sv, int max = -1) {
int count = 0;
@@ -68,7 +68,7 @@ static int json_brace_depth(const std::string & s) {
// JSON-escape a string and return the inner content (without surrounding quotes).
static std::string escape_json_string_inner(const std::string & s) {
std::string escaped = json(s).dump();
std::string escaped = ordered_json(s).dump();
if (escaped.size() >= 2 && escaped.front() == '"' && escaped.back() == '"') {
return escaped.substr(1, escaped.size() - 2);
}
@@ -309,7 +309,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
if (arg_count > 0) {
arg_entry = ",";
}
arg_entry += json(trim(node.text)).dump() + ":";
arg_entry += ordered_json(trim(node.text)).dump() + ":";
++arg_count;
auto & target = args_target();
@@ -343,7 +343,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
// Try to parse as JSON value (number, bool, null, object, array)
try {
json parsed = json::parse(value_content);
ordered_json parsed = ordered_json::parse(value_content);
if (parsed.is_string()) {
// Don't add closing quote yet (added by arg_close) for monotonic streaming
std::string escaped = parsed.dump();
@@ -408,7 +408,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
common_peg_parser common_chat_peg_builder::standard_constructed_tools(
const std::map<std::string, std::string> & markers,
const nlohmann::json & tools,
const ordered_json & tools,
bool parallel_tool_calls,
bool force_tool_calls) {
if (!tools.is_array() || tools.empty()) {
@@ -439,7 +439,7 @@ common_peg_parser common_chat_peg_builder::standard_constructed_tools(
}
const auto & function = tool_def.at("function");
std::string name = function.at("name");
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
// Build argument parsers
auto args = eps();
@@ -479,8 +479,8 @@ common_peg_parser common_chat_peg_builder::standard_constructed_tools(
// Python-style tool calls: name(arg1="value1", arg2=123)
// Used only by LFM2 for now, so we don't merge it into autoparser
common_peg_parser common_chat_peg_builder::python_style_tool_calls(
const nlohmann::json & tools,
bool parallel_tool_calls) {
const ordered_json & tools,
bool parallel_tool_calls) {
if (!tools.is_array() || tools.empty()) {
return eps();
}
@@ -493,7 +493,7 @@ common_peg_parser common_chat_peg_builder::python_style_tool_calls(
}
const auto & function = tool_def.at("function");
std::string name = function.at("name");
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
auto args = eps();
if (params.contains("properties") && !params["properties"].empty()) {
@@ -555,11 +555,11 @@ static std::pair<std::string, std::string> parse_key_spec(const std::string & ke
// Mode 1: function_is_key — parse {"function_name": {...}}
common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
const nlohmann::json & tools,
const std::string & args_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key) {
const ordered_json & tools,
const std::string & args_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key) {
auto tool_choices = choice();
@@ -569,7 +569,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
}
const auto & function = tool_def.at("function");
std::string name = function.at("name");
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
// Build inner object fields
std::vector<common_peg_parser> inner_fields;
@@ -634,11 +634,11 @@ common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key(
// Mode 2: Nested keys (dot notation like "function.name")
common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
const nlohmann::json & tools,
const std::string & effective_name_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key) {
const ordered_json & tools,
const std::string & effective_name_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key) {
auto tool_choices = choice();
@@ -655,7 +655,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
}
const auto & function = tool_def.at("function");
std::string name = function.at("name");
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
auto nested_name = literal("\"" + nested_name_field + "\"") + space() + literal(":") + space() +
literal("\"") + tool_name(literal(name)) + literal("\"");
@@ -706,7 +706,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
// Mode 3: Flat keys with optional ID fields and parameter ordering
common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
const nlohmann::json & tools,
const ordered_json & tools,
const std::string & effective_name_key,
const std::string & effective_args_key,
const std::string & call_id_key,
@@ -723,7 +723,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
}
const auto & function = tool_def.at("function");
std::string name = function.at("name");
nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object();
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
auto tool_name_ = name_key_parser + space() + literal(":") + space() +
literal("\"") + tool_name(literal(name)) + literal("\"");
@@ -791,7 +791,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
common_peg_parser common_chat_peg_builder::standard_json_tools(
const std::string & section_start,
const std::string & section_end,
const nlohmann::json & tools,
const ordered_json & tools,
bool parallel_tool_calls,
bool force_tool_calls,
const std::string & name_key,
+15 -15
View File
@@ -94,7 +94,7 @@ class common_chat_peg_builder : public common_peg_parser_builder {
// parameters_order: order in which JSON fields should be parsed
common_peg_parser standard_json_tools(const std::string & section_start,
const std::string & section_end,
const nlohmann::json & tools,
const nlohmann::ordered_json & tools,
bool parallel_tool_calls,
bool force_tool_calls,
const std::string & name_key = "",
@@ -108,30 +108,30 @@ class common_chat_peg_builder : public common_peg_parser_builder {
// Legacy-compatible helper for building XML/tagged style tool calls
// Used by tests and manual parsers
common_peg_parser standard_constructed_tools(const std::map<std::string, std::string> & markers,
const nlohmann::json & tools,
const nlohmann::ordered_json & tools,
bool parallel_tool_calls,
bool force_tool_calls);
// Helper for Python-style function call format: name(arg1="value1", arg2=123)
// Used by LFM2 and similar templates
common_peg_parser python_style_tool_calls(const nlohmann::json & tools,
bool parallel_tool_calls);
common_peg_parser python_style_tool_calls(const nlohmann::ordered_json & tools,
bool parallel_tool_calls);
private:
// Implementation helpers for standard_json_tools — one per JSON tool call layout mode
common_peg_parser build_json_tools_function_is_key(const nlohmann::json & tools,
const std::string & args_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key);
common_peg_parser build_json_tools_function_is_key(const nlohmann::ordered_json & tools,
const std::string & args_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key);
common_peg_parser build_json_tools_nested_keys(const nlohmann::json & tools,
const std::string & effective_name_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key);
common_peg_parser build_json_tools_nested_keys(const nlohmann::ordered_json & tools,
const std::string & effective_name_key,
const std::string & effective_args_key,
const std::string & call_id_key,
const std::string & gen_call_id_key);
common_peg_parser build_json_tools_flat_keys(const nlohmann::json & tools,
common_peg_parser build_json_tools_flat_keys(const nlohmann::ordered_json & tools,
const std::string & effective_name_key,
const std::string & effective_args_key,
const std::string & call_id_key,
+100 -6
View File
@@ -857,7 +857,9 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = true;
data.supports_thinking = true;
data.supports_thinking = true;
data.thinking_start_tag = "[THINK]";
data.thinking_end_tag = "[/THINK]";
data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.preserved_tokens = {
@@ -1165,9 +1167,11 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
const autoparser::templates_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.thinking_start_tag = "<think>";
data.thinking_end_tag = "</think>";
data.preserved_tokens = {
"<|tool_calls_section_begin|>",
"<|tool_calls_section_end|>",
@@ -1350,6 +1354,77 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
return data;
}
static common_chat_params common_chat_params_init_gigachat_v3(
const common_chat_template & tmpl,
const autoparser::templates_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = false;
data.preserved_tokens = {
"<|message_sep|>\n\n",
"<|role_sep|>\n",
};
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
auto tool_call_start_prefix = "<|message_sep|>\n\nfunction call<|role_sep|>\n";
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
// Build a choice of all available tools
auto tool_choice = p.choice();
for (const auto & tool : inputs.tools) {
const auto & function = tool.at("function");
std::string name = function.at("name");
const auto & schema = function.at("parameters");
auto tool_name = p.json_member("name", "\"" + p.tool_name(p.literal(name)) + "\"");
auto tool_args = p.json_member("arguments", p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)));
auto tool_open = p.tool_open(p.literal("{") << tool_name);
tool_choice |= p.rule("tool-" + name, tool_open << "," << tool_args << "}");
}
// Define the tool call structure
auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
auto max_calls = 1; // parallel toolcalls are not supported
auto tool_call = p.rule("tool-call", p.literal(tool_call_start_prefix) + tool_choice);
auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls));
return p.content(p.until("<|message_sep|>\n\n")) << tool_calls;
}
// Content only parser
include_grammar = false;
return p.content(p.rest());
});
data.parser = parser.save();
if (include_grammar) {
data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
auto schema = function.at("parameters");
builder.resolve_refs(schema);
});
parser.build_grammar(builder, data.grammar_lazy);
});
data.grammar_triggers = {
{COMMON_GRAMMAR_TRIGGER_TYPE_WORD, tool_call_start_prefix}
};
}
return data;
}
namespace workaround {
static void map_developer_role_to_system(json & messages) {
@@ -1521,12 +1596,31 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
return common_chat_params_init_lfm2(tmpl, params);
}
// GigaChatV3 format detection
if (src.find("<|role_sep|>") != std::string::npos &&
src.find("<|message_sep|>") != std::string::npos &&
src.find("<|function_call|>") == std::string::npos
) {
LOG_DBG("Using specialized template: GigaChatV3\n");
return common_chat_params_init_gigachat_v3(tmpl, params);
}
try {
LOG_DBG("Using differential autoparser\n");
struct autoparser::autoparser autoparser;
autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
if (auto_params.supports_thinking) {
auto_params.thinking_start_tag = autoparser.reasoning.start;
auto_params.thinking_end_tag = autoparser.reasoning.end;
// FORCED_OPEN and FORCED_CLOSED both put <think> in the generation prompt
// (FORCED_CLOSED forces empty <think></think> when thinking is disabled,
// but forces <think> open when thinking is enabled)
auto_params.thinking_forced_open =
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN ||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED;
}
return auto_params;
} catch (const std::exception & e) {
throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());
@@ -1620,8 +1714,8 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
build_chat_peg_parser([](common_chat_peg_builder & p) { return p.content(p.rest()) + p.end(); }) :
src_parser;
if (src_parser.empty()) {
LOG_WRN("No parser definition detected, assuming pure content parser.");
if (src_parser.empty()) {
LOG_DBG("No parser definition detected, assuming pure content parser.");
}
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str());
+2
View File
@@ -213,6 +213,8 @@ struct common_chat_params {
bool grammar_lazy = false;
bool thinking_forced_open = false;
bool supports_thinking = false;
std::string thinking_start_tag; // e.g., "<think>"
std::string thinking_end_tag; // e.g., "</think>"
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
+11 -1
View File
@@ -235,6 +235,14 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
// reasoning budget sampler parameters
// these are populated by the server/CLI based on chat template params
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
bool reasoning_budget_activate_immediately = false;
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)
bool backend_sampling = false;
bool has_logit_bias() const {
@@ -536,7 +544,9 @@ struct common_params {
bool use_jinja = true; // NOLINT
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable
int reasoning_budget = -1;
std::string reasoning_budget_message; // message injected before end tag when budget exhausted
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time
@@ -916,7 +926,7 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
// MoE utils
//
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps";
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate|gate_up)_(ch|)exps";
inline std::string llm_ffn_exps_block_regex(int idx) {
return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
+1 -1
View File
@@ -790,7 +790,7 @@ public:
} else if (target.is_array()) {
size_t sel_index;
try {
sel_index = std::stoul(sel);
sel_index = std::stoull(sel);
} catch (const std::invalid_argument & e) {
sel_index = target.size();
}
+219
View File
@@ -0,0 +1,219 @@
#include "reasoning-budget.h"
#include "common.h"
#include "unicode.h"
#include "log.h"
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>
struct token_matcher {
std::vector<llama_token> tokens;
size_t pos = 0;
bool advance(llama_token token) {
if (tokens.empty()) {
return false;
}
if (token == tokens[pos]) {
pos++;
if (pos >= tokens.size()) {
pos = 0;
return true;
}
} else {
pos = 0;
if (token == tokens[0]) {
pos = 1;
}
}
return false;
}
void reset() { pos = 0; }
};
struct common_reasoning_budget_ctx {
const llama_vocab * vocab;
token_matcher start_matcher;
token_matcher end_matcher;
std::vector<llama_token> forced_tokens;
int32_t budget; // maximum tokens in reasoning block
int32_t remaining; // tokens remaining in budget
common_reasoning_budget_state state;
// for forcing
size_t force_pos; // next position in forced_tokens to force
};
static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
return "reasoning-budget";
}
static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
switch (ctx->state) {
case REASONING_BUDGET_IDLE:
{
if (ctx->start_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_COUNTING;
ctx->remaining = ctx->budget;
LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
if (ctx->remaining <= 0) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
}
}
break;
}
case REASONING_BUDGET_COUNTING:
case REASONING_BUDGET_WAITING_UTF8:
{
if (ctx->end_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: deactivated (natural end)\n");
break;
}
bool utf8_complete = true;
if (ctx->vocab != nullptr) {
const std::string piece = common_token_to_piece(ctx->vocab, token, false);
utf8_complete = common_utf8_is_complete(piece);
}
if (ctx->state == REASONING_BUDGET_WAITING_UTF8) {
if (utf8_complete) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
}
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
ctx->remaining--;
if (ctx->remaining <= 0) {
if (utf8_complete) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
} else {
ctx->state = REASONING_BUDGET_WAITING_UTF8;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
}
}
}
break;
}
case REASONING_BUDGET_FORCING:
// force_pos is advanced in apply(), not here.
// This ensures the first forced token isn't skipped when the sampler
// is initialized directly in FORCING state (e.g. COUNTING + budget=0)
break;
case REASONING_BUDGET_DONE:
break;
}
}
static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
if (ctx->state != REASONING_BUDGET_FORCING) {
// passthrough — don't modify logits
return;
}
if (ctx->force_pos >= ctx->forced_tokens.size()) {
return;
}
const llama_token forced = ctx->forced_tokens[ctx->force_pos];
// set all logits to -inf except the forced token
for (size_t i = 0; i < cur_p->size; i++) {
if (cur_p->data[i].id != forced) {
cur_p->data[i].logit = -INFINITY;
}
}
// advance to next forced token (done here rather than in accept so that
// the first forced token isn't skipped when starting in FORCING state)
ctx->force_pos++;
if (ctx->force_pos >= ctx->forced_tokens.size()) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: forced sequence complete, done\n");
}
}
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
ctx->state = REASONING_BUDGET_IDLE;
ctx->remaining = ctx->budget;
ctx->start_matcher.reset();
ctx->end_matcher.reset();
ctx->force_pos = 0;
}
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
return common_reasoning_budget_init(
ctx->vocab,
ctx->start_matcher.tokens,
ctx->end_matcher.tokens,
ctx->forced_tokens,
ctx->budget,
ctx->state);
}
static void common_reasoning_budget_free(struct llama_sampler * smpl) {
delete (common_reasoning_budget_ctx *) smpl->ctx;
}
static struct llama_sampler_i common_reasoning_budget_i = {
/* .name = */ common_reasoning_budget_name,
/* .accept = */ common_reasoning_budget_accept,
/* .apply = */ common_reasoning_budget_apply,
/* .reset = */ common_reasoning_budget_reset,
/* .clone = */ common_reasoning_budget_clone,
/* .free = */ common_reasoning_budget_free,
/* .backend_init = */ nullptr,
/* .backend_accept = */ nullptr,
/* .backend_apply = */ nullptr,
/* .backend_set_input = */ nullptr,
};
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state) {
// promote COUNTING with budget <= 0 to FORCING
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
initial_state = REASONING_BUDGET_FORCING;
}
return llama_sampler_init(
/* .iface = */ &common_reasoning_budget_i,
/* .ctx = */ new common_reasoning_budget_ctx {
/* .vocab = */ vocab,
/* .start_matcher = */ { start_tokens, 0 },
/* .end_matcher = */ { end_tokens, 0 },
/* .forced_tokens = */ forced_tokens,
/* .budget = */ budget,
/* .remaining = */ budget,
/* .state = */ initial_state,
/* .force_pos = */ 0,
}
);
}
+41
View File
@@ -0,0 +1,41 @@
#pragma once
#include "llama.h"
#include <cstdint>
#include <vector>
enum common_reasoning_budget_state {
REASONING_BUDGET_IDLE, // waiting for start sequence
REASONING_BUDGET_COUNTING, // counting down tokens
REASONING_BUDGET_FORCING, // forcing budget message + end sequence
REASONING_BUDGET_WAITING_UTF8, // budget exhausted, waiting for UTF-8 completion
REASONING_BUDGET_DONE, // passthrough forever
};
// Creates a reasoning budget sampler that limits token generation inside a
// reasoning block (e.g. between <think> and </think>).
//
// State machine: IDLE -> COUNTING -> WAITING_UTF8 -> FORCING -> DONE
// IDLE: passthrough, watching for start_tokens sequence
// COUNTING: counting down remaining tokens, watching for natural end_tokens
// WAITING_UTF8: budget exhausted, allowing tokens to complete a UTF-8 sequence
// FORCING: forces forced_tokens token-by-token (all other logits -> -inf)
// DONE: passthrough forever
//
// Parameters:
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
// start_tokens - token sequence that activates counting
// end_tokens - token sequence for natural deactivation
// forced_tokens - token sequence forced when budget expires
// budget - max tokens allowed in the reasoning block
// initial_state - initial state of the sampler (e.g. IDLE or COUNTING)
// note: COUNTING with budget <= 0 is promoted to FORCING
//
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state);
+12
View File
@@ -2,6 +2,7 @@
#include "common.h"
#include "log.h"
#include "reasoning-budget.h"
#include <algorithm>
#include <cmath>
@@ -250,6 +251,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}
// reasoning budget sampler — added first so it can force tokens before other samplers
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
samplers.push_back(common_reasoning_budget_init(
vocab,
params.reasoning_budget_start,
params.reasoning_budget_end,
params.reasoning_budget_forced,
params.reasoning_budget_tokens,
params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE));
}
if (params.has_logit_bias()) {
samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
}
+17 -1
View File
@@ -1,8 +1,10 @@
#include "unicode.h"
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <vector>
#include <string>
#include <vector>
// implementation adopted from src/unicode.cpp
@@ -67,6 +69,20 @@ utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t off
return utf8_parse_result(utf8_parse_result::INVALID);
}
bool common_utf8_is_complete(const std::string & s) {
if (s.empty()) {
return true;
}
for (int i = 1; i <= std::min(4, (int)s.size()); i++) {
unsigned char c = s[s.size() - i];
if ((c & 0xC0) != 0x80) {
int expected = (c >= 0xF0) ? 4 : (c >= 0xE0) ? 3 : (c >= 0xC0) ? 2 : 1;
return i >= expected;
}
}
return false;
}
std::string common_unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
std::string result;
for (size_t i = 0; i < cps.size(); ++i) {
+3
View File
@@ -20,6 +20,9 @@ struct utf8_parse_result {
// Returns 0 for invalid first bytes
size_t common_utf8_sequence_length(unsigned char first_byte);
// Check if a string ends with a complete UTF-8 sequence.
bool common_utf8_is_complete(const std::string & s);
// Parse a single UTF-8 codepoint from input
utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset);
+322 -13
View File
@@ -144,6 +144,7 @@ class ModelBase:
self.metadata_override = metadata_override
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
self._is_nvfp4 = False
# Apply heuristics to figure out typical tensor encoding based on first tensor's dtype
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
@@ -271,6 +272,9 @@ class ModelBase:
return tensors
def dequant_model(self):
if self._is_nvfp4:
return # NVFP4 weights are repacked in _generate_nvfp4_tensors
tensors_to_remove: list[str] = []
new_tensors: dict[str, Callable[[], Tensor]] = {}
@@ -516,6 +520,13 @@ class ModelBase:
raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip NVFP4 auxiliary tensors (handled in _generate_nvfp4_tensors)
if self._is_nvfp4:
if name.endswith((".weight_scale", ".weight_scale_2", ".input_scale", ".k_scale", ".v_scale")):
return []
if name.endswith(".weight") and name.replace(".weight", ".weight_scale") in self.model_tensors:
return []
new_name = self.map_tensor_name(name)
# Handle gate/up expert tensor fusion if enabled
@@ -551,9 +562,135 @@ class ModelBase:
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
return ()
@staticmethod
def _nvfp4_pack(weight: Tensor, scale: Tensor) -> tuple[np.ndarray, list[int]]:
"""Repack NVFP4 ModelOpt tensors into ggml super-block layout.
Preserves original E4M3 scale bits as UE4M3 (strip sign bit).
The per-tensor scale2 factor is stored as a separate tensor and applied at inference time via ggml_mul().
Returns (raw_data, logical_shape)."""
out_features = weight.shape[0]
n_blocks = scale.shape[1]
# Unpack ModelOpt nibble-packed weights
w = weight.reshape(out_features, n_blocks, 8)
vals = torch.stack([w & 0x0F, w >> 4], dim=-1).reshape(out_features, n_blocks, 16)
# Preserve original E4M3 scale bits as UE4M3 (strip sign bit)
d_ue = scale.view(torch.uint8).numpy().reshape(out_features, n_blocks) & 0x7F
qs = (vals[:, :, :8] | (vals[:, :, 8:] << 4)).to(torch.uint8).numpy()
# Pack into super-blocks: [4 UE4M3 scales, 32 qs bytes] = 36 bytes per 64 elements
n_super = n_blocks // 4
d_grouped = d_ue.reshape(out_features, n_super, 4)
qs_grouped = qs.reshape(out_features, n_super, 4, 8).reshape(out_features, n_super, 32)
raw = np.concatenate([d_grouped, qs_grouped], axis=-1).reshape(out_features, n_super * 36)
return raw, [out_features, n_super * 64]
@staticmethod
def _nvfp4_scale2_is_trivial(scale2: Tensor) -> bool:
return scale2.numel() <= 1 and abs(float(scale2.float().sum()) - 1.0) < 1e-6
def _repack_nvfp4(self, new_name: str, weight: Tensor, scale: Tensor, scale2: Tensor):
raw, shape = self._nvfp4_pack(weight, scale)
logger.info(f"Repacked {new_name} with shape {shape} and quantization NVFP4")
self.gguf_writer.add_tensor(new_name, raw, raw_dtype=gguf.GGMLQuantizationType.NVFP4)
# Emit per-tensor scale2 as a separate F32 tensor when non-trivial
if not self._nvfp4_scale2_is_trivial(scale2):
scale2_f32 = scale2.float().numpy().flatten()
scale_name = new_name.replace(".weight", ".scale")
logger.info(f" + {scale_name} (per-tensor NVFP4 scale2, shape [{scale2_f32.size}])")
self.gguf_writer.add_tensor(scale_name, scale2_f32)
def _generate_nvfp4_tensors(self):
# Per-layer expert merging to avoid holding all experts in memory
expert_blocks: dict[tuple[int, str], list[tuple[int, np.ndarray]]] = {}
expert_scales: dict[tuple[int, str], list[tuple[int, float]]] = {}
expert_shapes: dict[tuple[int, str], list[int]] = {}
n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=True) or 0
for name in list(self.model_tensors.keys()):
if not name.endswith(".weight"):
continue
scale_name = name.replace(".weight", ".weight_scale")
scale2_name = name.replace(".weight", ".weight_scale_2")
if scale_name not in self.model_tensors:
continue
# Force eager materialization of lazy tensors
weight = LazyTorchTensor.to_eager(self.model_tensors[name]())
scale = LazyTorchTensor.to_eager(self.model_tensors[scale_name]())
scale2 = LazyTorchTensor.to_eager(self.model_tensors.get(scale2_name, lambda: torch.tensor(1.0))())
# Check if this is a per-expert tensor
m = re.search(r'\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight$', name)
if m:
expert_id = int(m.group(1))
proj_type = m.group(2)
bid_m = re.search(r'\.layers\.(\d+)\.', name)
bid = int(bid_m.group(1)) if bid_m else 0
key = (bid, proj_type)
raw, shape = self._nvfp4_pack(weight, scale)
if key not in expert_blocks:
expert_blocks[key] = []
expert_scales[key] = []
expert_shapes[key] = shape
expert_blocks[key].append((expert_id, raw.copy()))
# Collect per-expert scale2 (scalar per expert)
expert_scales[key].append((expert_id, float(scale2.float().sum())))
# Flush when all experts for this (layer, proj) are collected
if n_experts > 0 and len(expert_blocks[key]) >= n_experts:
self._flush_nvfp4_experts(key, expert_blocks, expert_scales, expert_shapes, bid, proj_type)
else:
new_name = self.map_tensor_name(name)
self._repack_nvfp4(new_name, weight, scale, scale2)
# Flush any remaining experts (fallback if n_experts was unknown)
for (bid, proj_type) in list(expert_blocks.keys()):
self._flush_nvfp4_experts((bid, proj_type), expert_blocks, expert_scales, expert_shapes, bid, proj_type)
def _flush_nvfp4_experts(self, key, expert_blocks, expert_scales, expert_shapes, bid, proj_type):
experts = expert_blocks.pop(key)
scales = expert_scales.pop(key)
shape = expert_shapes.pop(key)
experts.sort(key=lambda x: x[0])
merged = np.stack([e[1] for e in experts], axis=0)
merged_name = f"model.layers.{bid}.mlp.experts.{proj_type}.weight"
new_name = self.map_tensor_name(merged_name)
logger.info(f"Repacked {new_name} with shape [{len(experts)}, {shape[0]}, {shape[1]}] and quantization NVFP4")
self.gguf_writer.add_tensor(new_name, merged, raw_dtype=gguf.GGMLQuantizationType.NVFP4)
# Emit per-expert scale2 tensor if any expert has non-trivial scale2
scales.sort(key=lambda x: x[0])
scale_vals = np.array([s[1] for s in scales], dtype=np.float32)
if not np.allclose(scale_vals, 1.0, atol=1e-6):
scale_name = new_name.replace(".weight", ".scale")
logger.info(f" + {scale_name} (per-expert NVFP4 scale2, shape [{len(scales)}])")
self.gguf_writer.add_tensor(scale_name, scale_vals)
del experts, merged
def prepare_tensors(self):
# detect NVFP4 quantization (ModelOpt format)
quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
quant_config_file = self.dir_model / "hf_quant_config.json"
if not quant_algo and quant_config_file.is_file():
with open(quant_config_file, "r", encoding="utf-8") as f:
quant_algo = (json.load(f).get("quantization") or {}).get("quant_algo")
self._is_nvfp4 = quant_algo == "NVFP4"
self.dequant_model()
# NVFP4 weights are repacked and written directly to gguf_writer
if self._is_nvfp4:
self._generate_nvfp4_tensors()
# Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
if self.tensor_map.mapping:
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
@@ -4303,6 +4440,14 @@ class Qwen2MoeModel(TextModel):
# process the experts separately
name = name.replace("language_model.", "") # InternVL
# NVFP4 expert weights are handled in _generate_nvfp4_tensors
if self._is_nvfp4 and "experts" in name:
if name.endswith((".weight", ".weight_scale", ".weight_scale_2", ".input_scale")):
if name.endswith(".weight") and name.replace(".weight", ".weight_scale") in self.model_tensors:
return
if not name.endswith(".weight"):
return
# handle aggregated expert tensors
# GGUF stores dimensions reversed from PyTorch, so:
# PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
@@ -4390,15 +4535,31 @@ class Qwen3Model(Qwen2Model):
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
self.origin_hf_arch = hparams.get('architectures', [None])[0]
# a bit hacky, but currently the only way to detect if this is a rerank model
# ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
if self._is_qwen3_reranker():
self._find_rerank_config()
def _is_qwen3_reranker(self) -> bool:
readme_path = self.dir_model / "README.md"
readme_text = ""
if readme_path.exists():
with readme_path.open("r", encoding="utf-8") as f:
readme_text = f.read()
if "# Qwen3-Reranker" in readme_text:
self._find_rerank_config()
name_hints = [
str(self.dir_model.name),
str(self.hparams.get("_name_or_path", "")),
str(self.hparams.get("model_type", "")),
str(self.origin_hf_arch or ""),
]
name_hints = [hint.lower() for hint in name_hints if hint]
if "# qwen3-reranker" in readme_text.lower() or "# qwen3-vl-reranker" in readme_text.lower():
return True
if any("qwen3-reranker" in hint or "qwen3-vl-reranker" in hint for hint in name_hints):
return True
return "sequenceclassification" in (self.origin_hf_arch or "").lower()
def set_vocab(self):
# deal with intern-s1-mini
@@ -4901,7 +5062,7 @@ class Phi2Model(TextModel):
self.gguf_writer.add_add_bos_token(False)
@ModelBase.register("Phi3ForCausalLM")
@ModelBase.register("Phi3ForCausalLM", "Phi4ForCausalLMV")
class Phi3MiniModel(TextModel):
model_arch = gguf.MODEL_ARCH.PHI3
@@ -5076,6 +5237,129 @@ class Phi3MiniModel(TextModel):
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith(("model.vision_tower.", "vision_tower.", "model.mm_projector.", "mm_projector.")):
return
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Phi4ForCausalLMV")
class Phi4VisionMmprojModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
self.vision_total_layers = int(self.find_vparam(self.n_block_keys))
if self.vision_total_layers < 2:
raise ValueError(
f"Phi-4 vision mmproj conversion requires at least 2 vision layers, got {self.vision_total_layers}"
)
# Phi-4 uses SigLIP2 hidden_states[-2], so export one fewer encoder block and
# drop post-layernorm/head weights. This makes the GGUF runtime output match
# the feature map consumed by the patched siglip.cpp Phi-4 projector path.
self.vision_export_layers = self.vision_total_layers - 1
self.vision_last_layer_idx = self.vision_total_layers - 1
for key in self.n_block_keys:
if key in self.hparams_vision:
self.hparams_vision[key] = self.vision_export_layers
break
self.block_count = self.vision_export_layers
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
patch_size = self.preprocessor_config.get("patch_size")
if patch_size is None:
raise KeyError("Phi-4 vision mmproj conversion requires patch_size in preprocessor_config.json")
self.hparams_vision["patch_size"] = patch_size
pos_emb_name = next(
(
name for name in self.model_tensors
if name.endswith("vision_model.embeddings.position_embedding.weight")
),
None,
)
if pos_emb_name is None:
raise KeyError("Phi-4 vision mmproj conversion could not find position_embedding.weight")
pos_emb_shape = self.model_tensors[pos_emb_name]().shape
base_grid_tokens = int(pos_emb_shape[0])
grid_side = math.isqrt(base_grid_tokens)
if grid_side * grid_side != base_grid_tokens:
raise ValueError(f"Unexpected Phi-4 position embedding shape: {tuple(pos_emb_shape)}")
self.hparams_vision["image_size"] = grid_side * patch_size
min_num_patches = self.preprocessor_config.get("min_num_patches", self.global_config.get("min_num_patches"))
max_num_patches = self.preprocessor_config.get("max_num_patches", self.global_config.get("max_num_patches"))
if min_num_patches is None or max_num_patches is None:
raise KeyError("Phi-4 vision mmproj conversion requires min_num_patches and max_num_patches")
self.min_pixels = int(min_num_patches) * patch_size * patch_size
self.max_pixels = int(max_num_patches) * patch_size * patch_size
def set_gguf_parameters(self):
super().set_gguf_parameters()
assert self.hparams_vision is not None
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PHI4)
self.gguf_writer.add_vision_min_pixels(self.min_pixels)
self.gguf_writer.add_vision_max_pixels(self.max_pixels)
self.gguf_writer.add_vision_use_gelu(True)
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-6))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith(("model.vision_tower.vision_tower.", "vision_tower.")):
if ".vision_model.head." in name:
return
new_name = name.replace("model.vision_tower.vision_tower.", "vision_tower.")
if ".vision_model.post_layernorm." in new_name:
return
if bid is not None and bid == self.vision_last_layer_idx:
return
if new_name.endswith("vision_model.embeddings.patch_embedding.weight"):
assert self.hparams_vision is not None
if data_torch.ndim != 2:
raise ValueError(f"Unexpected Phi-4 patch embedding shape: {tuple(data_torch.shape)}")
patch_area = self.hparams_vision["patch_size"] ** 2
in_features = data_torch.shape[1]
if in_features % patch_area != 0:
raise ValueError(
f"Phi-4 patch embedding input dim {in_features} is not divisible by patch area {patch_area}"
)
num_channels = in_features // patch_area
patch_size = self.hparams_vision["patch_size"]
data_torch = data_torch.view(data_torch.shape[0], patch_size, patch_size, num_channels)
data_torch = data_torch.permute(0, 3, 1, 2)
yield from super().modify_tensors(data_torch, new_name, bid)
return
if name.startswith(("model.mm_projector.", "mm_projector.")):
local_name = name
local_name = local_name.replace("model.mm_projector.", "")
local_name = local_name.replace("mm_projector.", "")
if not (local_name.startswith("0.") or local_name.startswith("2.")):
return
suffix = ".bias" if local_name.endswith(".bias") else ".weight"
mm_idx = int(local_name.split(".", maxsplit=1)[0])
yield (self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, mm_idx, suffix=suffix), data_torch)
return
return
@ModelBase.register("PhiMoEForCausalLM")
class PhiMoeModel(Phi3MiniModel):
@@ -9727,20 +10011,35 @@ class NemotronHModel(GraniteHybridModel):
# M: Mamba2, *: Attention, -: MLP
# MoE:
# M: Mamba2, *: Attention, E: Expert
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == ("E" if self.is_moe else "-")]
pattern = self.hparams.get("hybrid_override_pattern") or self.hparams.get("layers_block_type")
if pattern is None:
self._ssm_layers = []
self._mlp_layers = []
elif isinstance(pattern, str):
self._ssm_layers = [i for i, val in enumerate(pattern) if val == "M"]
self._mlp_layers = [i for i, val in enumerate(pattern) if val == ("E" if self.is_moe else "-")]
else:
self._ssm_layers = [i for i, val in enumerate(pattern) if val == "mamba"]
self._mlp_layers = [i for i, val in enumerate(pattern) if val == "moe"]
def get_attn_layers(self):
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!"
return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"]
pattern = self.hparams.get("hybrid_override_pattern") or self.hparams.get("layers_block_type")
if pattern is None:
return []
assert len(pattern) == self.block_count, f"Mismatch between pattern ({len(pattern)}) and block_count ({self.block_count})!"
if isinstance(pattern, str):
return [i for i, val in enumerate(pattern) if val == "*"]
return [i for i, val in enumerate(pattern) if val == "attention"]
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_key_length(self.head_dim)
self.gguf_writer.add_value_length(self.head_dim)
head_dim = self.head_dim
if head_dim is None:
raise ValueError("Could not find the attention head dim in config")
self.gguf_writer.add_key_length(head_dim)
self.gguf_writer.add_value_length(head_dim)
# Set feed_forward_length
# NOTE: This will trigger an override warning. This is preferable to
@@ -9768,6 +10067,9 @@ class NemotronHModel(GraniteHybridModel):
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)
if (latent_size := self.hparams.get("moe_latent_size")) is not None:
self.gguf_writer.add_moe_latent_size(latent_size)
def set_vocab(self):
super().set_vocab()
@@ -9787,6 +10089,13 @@ class NemotronHModel(GraniteHybridModel):
name = name[len("language_model."):]
if self.is_moe and bid is not None:
# Skip Multi-Token Prediction (MTP) tensors. These are used for
# for speculative decoding but we don't include them in this model
# conversion. See https://github.com/ggml-org/llama.cpp/pull/18886
if name.startswith("mtp."):
logger.info(f"gguf: Skipping MTP (Speculative) layer: {name}")
return
if name.endswith("mixer.gate.e_score_correction_bias"):
new_name = name.replace("e_score_correction_bias", "e_score_correction.bias")
yield from ModelBase.modify_tensors(self, data_torch, new_name, bid)
+27 -17
View File
@@ -382,17 +382,27 @@ use 1 SYCL GPUs: [0] with Max compute units:512
## Windows
### I. Setup Environment
1. Install GPU driver
### Install GPU driver
Intel GPU drivers instructions guide and download page can be found here: [Get Intel GPU Drivers](https://www.intel.com/content/www/us/en/products/docs/discrete-gpus/arc/software/drivers.html).
2. Install Visual Studio
### Option 1: download the binary package directly
Download the binary package for Windows from: https://github.com/ggml-org/llama.cpp/releases.
Extract the package to local folder, run the llama tools directly. Refer to [Run the inference](#iii-run-the-inference-1).
Note, the package includes the SYCL running time and all depended dll files, no need to install oneAPI package and activte them.
### Option 2: build locally from the source code.
#### I. Setup environment
1. Install Visual Studio
If you already have a recent version of Microsoft Visual Studio, you can skip this step. Otherwise, please refer to the official download page for [Microsoft Visual Studio](https://visualstudio.microsoft.com/).
3. Install Intel® oneAPI Base toolkit
2. Install Intel® oneAPI Base toolkit
SYCL backend depends on:
- Intel® oneAPI DPC++/C++ compiler/running-time.
@@ -443,25 +453,25 @@ Output (example):
[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Iris(R) Xe Graphics 1.3 [1.3.28044]
```
4. Install build tools
3. Install build tools
a. Download & install cmake for Windows: https://cmake.org/download/ (CMake can also be installed from Visual Studio Installer)
b. The new Visual Studio will install Ninja as default. (If not, please install it manually: https://ninja-build.org/)
### II. Build llama.cpp
#### II. Build llama.cpp
You could download the release package for Windows directly, which including binary files and depended oneAPI dll files.
Choose one of following methods to build from source code.
#### 1. Script
##### Option 1: Script
```sh
.\examples\sycl\win-build-sycl.bat
```
#### 2. CMake
##### Option 2: CMake
On the oneAPI command line window, step into the llama.cpp main directory and run the following:
@@ -490,7 +500,7 @@ cmake --preset x64-windows-sycl-debug
cmake --build build-x64-windows-sycl-debug -j --target llama-completion
```
#### 3. Visual Studio
##### Option 3: Visual Studio
You have two options to use Visual Studio to build llama.cpp:
- As CMake Project using CMake presets.
@@ -500,7 +510,7 @@ You have two options to use Visual Studio to build llama.cpp:
All following commands are executed in PowerShell.
##### - Open as a CMake Project
###### - Open as a CMake Project
You can use Visual Studio to open the `llama.cpp` folder directly as a CMake project. Before compiling, select one of the SYCL CMake presets:
@@ -515,7 +525,7 @@ You can use Visual Studio to open the `llama.cpp` folder directly as a CMake pro
cmake --build build --config Release -j --target llama-completion
```
##### - Generating a Visual Studio Solution
###### - Generating a Visual Studio Solution
You can use Visual Studio solution to build and work on llama.cpp on Windows. You need to convert the CMake Project into a `.sln` file.
@@ -603,7 +613,7 @@ found 2 SYCL devices:
```
#### Choose level-zero devices
##### Choose level-zero devices
|Chosen Device ID|Setting|
|-|-|
@@ -611,7 +621,7 @@ found 2 SYCL devices:
|1|`set ONEAPI_DEVICE_SELECTOR="level_zero:1"`|
|0 & 1|`set ONEAPI_DEVICE_SELECTOR="level_zero:0;level_zero:1"` or `set ONEAPI_DEVICE_SELECTOR="level_zero:*"`|
#### Execute
##### Execute
Choose one of following methods to run.
@@ -669,7 +679,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
## Environment Variable
#### Build
### Build
| Name | Value | Function |
|--------------------|---------------------------------------|---------------------------------------------|
@@ -684,7 +694,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
1. FP32 or FP16 have different performance impact to LLM. Recommended to test them for better prompt processing performance on your models. You need to rebuild the code after change `GGML_SYCL_F16=OFF/ON`.
#### Runtime
### Runtime
| Name | Value | Function |
|-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
@@ -777,7 +787,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
```
### **GitHub contribution**:
Please add the `SYCL :` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay.
Please add the `[SYCL]` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay.
## TODO
+4 -2
View File
@@ -55,7 +55,8 @@ LLAMA_MAC_BUILD=$PWD/build/ggml-virtgpu-backend
cmake -S . -B $LLAMA_MAC_BUILD \
-DGGML_NATIVE=OFF \
-DLLAMA_CURL=ON \
-DGGML_REMOTINGBACKEND=ONLY \
-DGGML_VIRTGPU=ON \
-DGGML_VIRTGPU_BACKEND=ONLY \
-DGGML_METAL=ON
TARGETS="ggml-metal"
@@ -71,6 +72,7 @@ cmake --build $LLAMA_MAC_BUILD --parallel 8 --target $EXTRA_TARGETS
```bash
# Build virglrenderer with APIR support
mkdir virglrenderer
cd virglrenderer
git clone https://gitlab.freedesktop.org/kpouget/virglrenderer -b main-macos src
cd src
@@ -95,7 +97,7 @@ mkdir llama.cpp
git clone https://github.com/ggml-org/llama.cpp.git src
cd src
LLAMA_LINUX_BUILD=$PWD//build-virtgpu
LLAMA_LINUX_BUILD=$PWD/build-virtgpu
cmake -S . -B $LLAMA_LINUX_BUILD \
-DGGML_VIRTGPU=ON
+10 -10
View File
@@ -23,7 +23,7 @@ Legend:
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
@@ -31,7 +31,7 @@ Legend:
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | 🟡 | ✅ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
@@ -47,7 +47,7 @@ Legend:
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
@@ -64,7 +64,7 @@ Legend:
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | ✅ | ✅ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
@@ -76,17 +76,17 @@ Legend:
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | | ❌ | ❌ |
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | | ❌ | ❌ |
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ROPE | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
@@ -97,13 +97,13 @@ Legend:
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | 🟡 | ✅ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | | 🟡 | ✅ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
+1689 -6836
View File
File diff suppressed because it is too large Load Diff
+1137 -12992
View File
File diff suppressed because it is too large Load Diff
+14 -14
View File
@@ -5023,20 +5023,20 @@
"WebGPU: WebGPU","ARGMAX","type=f32,ne=[1024,12,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","ARGMAX","type=f32,ne=[2000,10,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","ARGMAX","type=f32,ne=[5438,3,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[2,1,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,2,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,2,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,2]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,1],nr=[2,1,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,1],nr=[1,1,1,2]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[2,1,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,2,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,2,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,2]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,3],nr=[2,1,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,3],nr=[1,1,1,2]","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[2,1,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,2,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,2,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,1],nr=[2,1,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,1],nr=[1,1,1,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[2,1,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,2,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,2,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,3],nr=[2,1,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,3],nr=[1,1,1,2]","support","1","yes","WebGPU"
"WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[1,1,1,1],v=0","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[2,1,1,1],v=0","support","0","no","WebGPU"
"WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[1,2,1,1],v=0","support","0","no","WebGPU"
Can't render this file because it is too large.
+1 -1
View File
@@ -633,7 +633,7 @@ class SchemaConverter:
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
items = schema.get('items') or schema['prefixItems']
items = schema.get('items', schema.get('prefixItems'))
if isinstance(items, list):
return self._add_rule(
rule_name,
+6 -1
View File
@@ -8,7 +8,12 @@ extern "C" {
#define RPC_PROTO_MAJOR_VERSION 3
#define RPC_PROTO_MINOR_VERSION 6
#define RPC_PROTO_PATCH_VERSION 0
#define RPC_PROTO_PATCH_VERSION 1
#ifdef __cplusplus
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");
#endif
#define GGML_RPC_MAX_SERVERS 16
// backend API
+5 -1
View File
@@ -427,7 +427,8 @@ extern "C" {
// GGML_TYPE_IQ4_NL_4_8 = 37,
// GGML_TYPE_IQ4_NL_8_8 = 38,
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
GGML_TYPE_COUNT = 40,
GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
GGML_TYPE_COUNT = 41,
};
// precision
@@ -463,6 +464,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors
GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors
};
// available tensor operations:
@@ -2464,6 +2466,8 @@ extern "C" {
bool lower,
bool uni);
// TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST]
// ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306
GGML_API struct ggml_tensor * ggml_gated_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q,
+11
View File
@@ -102,6 +102,9 @@ typedef sycl::half2 ggml_half2;
#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
#define QR_MXFP4 2
#define QI_NVFP4 (QK_NVFP4 / (4 * QR_NVFP4))
#define QR_NVFP4 2
#define QI5_0 (QK5_0 / (4 * QR5_0))
#define QR5_0 2
@@ -194,6 +197,14 @@ typedef struct {
} block_mxfp4;
static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
#define QK_NVFP4 64
#define QK_NVFP4_SUB 16 // sub-block size for per-group scales
typedef struct {
uint8_t d[QK_NVFP4/QK_NVFP4_SUB]; // UE4M3 scales (4 bytes, one per 16-element sub-block)
uint8_t qs[QK_NVFP4/2]; // packed 4-bit E2M1 values (32 bytes)
} block_nvfp4;
static_assert(sizeof(block_nvfp4) == sizeof(uint8_t)*(QK_NVFP4/QK_NVFP4_SUB) + QK_NVFP4/2, "wrong nvfp4 block size/padding");
#define QK5_0 32
typedef struct {
ggml_half d; // delta
+8
View File
@@ -15,6 +15,7 @@
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -79,6 +80,8 @@
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
// quants.c
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
@@ -108,6 +111,7 @@
// ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
// quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
@@ -155,6 +159,7 @@
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -201,6 +206,7 @@
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
@@ -240,6 +246,7 @@
#elif defined(__s390x__)
// quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -302,6 +309,7 @@
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+84
View File
@@ -650,6 +650,90 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
*s = sumf;
}
void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
assert(n % QK_NVFP4 == 0);
const block_nvfp4 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
// Each NVFP4 super-block (64 elements) spans 2 q8_0 blocks
const int nb = n / QK_NVFP4;
float sumf = 0;
#if defined __ARM_NEON
const int8x16_t values = vld1q_s8(kvalues_mxfp4);
const uint8x16_t m4b = vdupq_n_u8(0x0f);
float32x4_t acc = vdupq_n_f32(0.0f);
for (int ib = 0; ib < nb; ++ib) {
const uint8x16_t q4bits_0 = vld1q_u8(x[ib].qs);
const uint8x16_t q4bits_1 = vld1q_u8(x[ib].qs + 16);
const int8x16_t q4_lo_0 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_0, m4b));
const int8x16_t q4_hi_0 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_0, 4));
const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_1, m4b));
const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_1, 4));
const int8x16_t q8_0a = vld1q_s8(y[2*ib].qs);
const int8x16_t q8_0b = vld1q_s8(y[2*ib].qs + 16);
const int8x16_t q8_lo_0 = vcombine_s8(vget_low_s8(q8_0a), vget_low_s8(q8_0b));
const int8x16_t q8_hi_0 = vcombine_s8(vget_high_s8(q8_0a), vget_high_s8(q8_0b));
const int8x16_t q8_1a = vld1q_s8(y[2*ib+1].qs);
const int8x16_t q8_1b = vld1q_s8(y[2*ib+1].qs + 16);
const int8x16_t q8_lo_1 = vcombine_s8(vget_low_s8(q8_1a), vget_low_s8(q8_1b));
const int8x16_t q8_hi_1 = vcombine_s8(vget_high_s8(q8_1a), vget_high_s8(q8_1b));
const int32x4_t p0 = vaddq_s32(
ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0),
ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0));
const int32x4_t p1 = vaddq_s32(
ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1),
ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1));
const int32x4_t sums = vpaddq_s32(p0, p1);
// Decode 4 UE4M3 scales to f32 and multiply with q8 scales
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
const float32x4_t nvsc = {
ggml_ue4m3_to_fp32(x[ib].d[0]),
ggml_ue4m3_to_fp32(x[ib].d[1]),
ggml_ue4m3_to_fp32(x[ib].d[2]),
ggml_ue4m3_to_fp32(x[ib].d[3])
};
const float32x4_t scales = vmulq_f32(nvsc, (float32x4_t){dy0, dy0, dy1, dy1});
acc = vfmaq_f32(acc, vcvtq_f32_s32(sums), scales);
}
sumf = vaddvq_f32(acc);
#else
for (int ib = 0; ib < nb; ++ib) {
for (int si = 0; si < 4; ++si) {
const float d = ggml_ue4m3_to_fp32(x[ib].d[si]);
const int q8b = si / 2;
const int q8o = (si % 2) * QK_NVFP4_SUB;
const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8b].d);
int sumi_lo = 0, sumi_hi = 0;
for (int j = 0; j < QK_NVFP4_SUB/2; ++j) {
const uint8_t qv = x[ib].qs[si*(QK_NVFP4_SUB/2) + j];
sumi_lo += y[2*ib + q8b].qs[q8o + j + 0] * kvalues_mxfp4[qv & 0xf];
sumi_hi += y[2*ib + q8b].qs[q8o + j + QK_NVFP4_SUB/2] * kvalues_mxfp4[qv >> 4];
}
sumf += dy * d * (sumi_lo + sumi_hi);
}
}
#endif
*s = sumf;
}
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
+6
View File
@@ -270,6 +270,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_NVFP4] = {
.from_float = quantize_row_nvfp4,
.vec_dot = ggml_vec_dot_nvfp4_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_Q2_K] = {
.from_float = quantize_row_q2_K,
.vec_dot = ggml_vec_dot_q2_K_q8_K,
+12 -6
View File
@@ -670,6 +670,7 @@ void ggml_compute_forward_add(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -1119,6 +1120,7 @@ void ggml_compute_forward_add1(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -1247,6 +1249,7 @@ void ggml_compute_forward_acc(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -4334,6 +4337,7 @@ void ggml_compute_forward_out_prod(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -4609,6 +4613,7 @@ void ggml_compute_forward_set(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -4831,6 +4836,7 @@ void ggml_compute_forward_get_rows(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -5555,6 +5561,7 @@ void ggml_compute_forward_clamp(
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -10436,8 +10443,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const float * state_in_base = (const float *)src_state->data;
const int64_t rq1 = nev1 / neq1;
const int64_t rk1 = nev1 / nek1;
//const int64_t rq1 = nev1 / neq1;
//const int64_t rk1 = nev1 / nek1;
const int64_t rq3 = nev3 / neq3;
const int64_t rk3 = nev3 / nek3;
@@ -10447,8 +10454,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const int64_t iv1 = ir % H; // head_index
const int64_t iv3 = ir / H; // sequence
const int64_t iq1 = iv1 / rq1;
const int64_t ik1 = iv1 / rk1;
const int64_t iq1 = iv1 % neq1;
const int64_t ik1 = iv1 % nek1;
const int64_t iq3 = iv3 / rq3;
const int64_t ik3 = iv3 / rk3;
@@ -10468,7 +10475,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
if (kda) {
for (int64_t i = 0; i < S_v; ++i) {
@@ -10501,7 +10508,6 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
attn_data += S_v * H; // advance to next token
}
}
}
+40
View File
@@ -50,6 +50,10 @@ void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i
quantize_row_mxfp4_ref(x, y, k);
}
void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
quantize_row_nvfp4_ref(x, y, k);
}
//
// 2-6 bit quantization in super-blocks
//
@@ -216,6 +220,42 @@ void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
*s = sumf;
}
// NVFP4: super-block of 64 elements = 4 sub-blocks of 16 = 2 q8_0 blocks
void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
assert(n % QK_NVFP4 == 0);
const block_nvfp4 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;
const int nb = n / QK_NVFP4;
float sumf = 0;
for (int ib = 0; ib < nb; ++ib) {
for (int s_idx = 0; s_idx < 4; ++s_idx) {
const float d = ggml_ue4m3_to_fp32(x[ib].d[s_idx]);
const int q8_block = s_idx / 2;
const int q8_off = (s_idx % 2) * QK_NVFP4_SUB;
const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8_block].d);
int sumi_lo = 0, sumi_hi = 0;
for (int j = 0; j < QK_NVFP4_SUB/2; ++j) {
const uint8_t qv = x[ib].qs[s_idx*(QK_NVFP4_SUB/2) + j];
sumi_lo += y[2*ib + q8_block].qs[q8_off + j + 0] * kvalues_mxfp4[qv & 0xf];
sumi_hi += y[2*ib + q8_block].qs[q8_off + j + QK_NVFP4_SUB/2] * kvalues_mxfp4[qv >> 4];
}
sumf += dy * d * (sumi_lo + sumi_hi);
}
}
*s = sumf;
}
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
+3
View File
@@ -20,6 +20,7 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -42,6 +43,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -73,6 +75,7 @@ void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+102 -50
View File
@@ -1,5 +1,4 @@
#include "gated_delta_net.cuh"
#include "ggml-cuda/common.cuh"
template <int S_v, bool KDA>
__global__ void gated_delta_net_cuda(const float * q,
@@ -21,15 +20,17 @@ __global__ void gated_delta_net_cuda(const float * q,
int64_t sb1,
int64_t sb2,
int64_t sb3,
int64_t rq1,
int64_t rq3,
const uint3 neqk1_magic,
const uint3 rq3_magic,
float scale) {
const int64_t h_idx = blockIdx.x;
const int64_t sequence = blockIdx.y;
const int col = threadIdx.x; // each thread owns one column
const uint32_t h_idx = blockIdx.x;
const uint32_t sequence = blockIdx.y;
// each warp owns one column, using warp-level primitives to reduce across rows
const int lane = threadIdx.x;
const int col = blockIdx.z * blockDim.y + threadIdx.y;
const int64_t iq1 = h_idx / rq1;
const int64_t iq3 = sequence / rq3;
const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
const uint32_t iq3 = fastdiv(sequence, rq3_magic);
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
float * attn_data = dst;
@@ -40,11 +41,14 @@ __global__ void gated_delta_net_cuda(const float * q,
curr_state += state_offset;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
// Load state column into registers
float s[S_v];
constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
float s_shard[rows_per_lane];
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = curr_state[i * S_v + col];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = curr_state[i * S_v + col];
}
for (int t = 0; t < n_tokens; t++) {
@@ -62,46 +66,61 @@ __global__ void gated_delta_net_cuda(const float * q,
const float g_val = expf(*g_t);
// kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
float kv_col = 0.0f;
float kv_shard = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
kv_col += s[i] * k_t[i];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += s_shard[r] * k_t[i];
}
float kv_col = warp_reduce_sum<warp_size>(kv_shard);
// delta[col] = (v[col] - g * kv[col]) * beta
float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
// fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_col = 0.0f;
float attn_partial = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = g_val * s[i] + k_t[i] * delta_col;
attn_col += s[i] * q_t[i];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
}
attn_data[col] = attn_col * scale;
float attn_col = warp_reduce_sum<warp_size>(attn_partial);
if (lane == 0) {
attn_data[col] = attn_col * scale;
}
} else {
// kv[col] = sum_i g[i] * S[i][col] * k[i]
float kv_col = 0.0f;
float kv_shard = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
kv_col += expf(g_t[i]) * s[i] * k_t[i];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
}
float kv_col = warp_reduce_sum<warp_size>(kv_shard);
// delta[col] = (v[col] - kv[col]) * beta
float delta_col = (v_t[col] - kv_col) * beta_val;
// fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_col = 0.0f;
float attn_partial = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col;
attn_col += s[i] * q_t[i];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
}
attn_data[col] = attn_col * scale;
float attn_col = warp_reduce_sum<warp_size>(attn_partial);
if (lane == 0) {
attn_data[col] = attn_col * scale;
}
}
attn_data += S_v * H;
@@ -109,45 +128,74 @@ __global__ void gated_delta_net_cuda(const float * q,
// Write state back to global memory
#pragma unroll
for (int i = 0; i < S_v; i++) {
state[i * S_v + col] = s[i];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
state[i * S_v + col] = s_shard[r];
}
}
static size_t calculate_smem(const int sv, int cc)
{
size_t smem = 0;
if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
smem = sv * sv * sizeof(float);
}
return smem;
}
template <bool KDA>
static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
const float * g_d, const float * b_d, const float * s_d,
float * dst_d,
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
int64_t sq1, int64_t sq2, int64_t sq3,
int64_t sv1, int64_t sv2, int64_t sv3,
int64_t sb1, int64_t sb2, int64_t sb3,
int64_t rq1, int64_t rq3,
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
int64_t sq1, int64_t sq2, int64_t sq3,
int64_t sv1, int64_t sv2, int64_t sv3,
int64_t sb1, int64_t sb2, int64_t sb3,
int64_t neqk1, int64_t rq3,
float scale, cudaStream_t stream) {
//TODO: Add chunked kernel for even faster pre-fill
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const int num_warps = 4;
dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);
dim3 grid_dims(H, n_seqs, 1);
dim3 block_dims(S_v, 1, 1);
const uint3 neqk1_magic = init_fastdiv_values(neqk1);
const uint3 rq3_magic = init_fastdiv_values(rq3);
int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
switch (S_v) {
case 16:
gated_delta_net_cuda<16, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
case 32:
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
case 64:
gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
case 64: {
constexpr int sv = 64;
size_t smem = calculate_smem(sv, cc);
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
case 128:
gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
}
case 128: {
constexpr int sv = 128;
size_t smem = calculate_smem(sv, cc);
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
}
default:
GGML_ABORT("fatal error");
break;
@@ -163,10 +211,12 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
ggml_tensor * src_state = dst->src[5];
GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);
GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);
GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
const int64_t S_v = nev0;
const int64_t H = nev1;
@@ -175,7 +225,9 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
const bool kda = (src_g->ne[0] == S_v);
const int64_t rq1 = nev1 / neq1;
GGML_ASSERT(neq1 == nek1);
const int64_t neqk1 = neq1;
const int64_t rq3 = nev3 / neq3;
const float * q_d = (const float *) src_q->data;
@@ -214,10 +266,10 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
if (kda) {
launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale, stream);
sb1, sb2, sb3, neqk1, rq3, scale, stream);
} else {
launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale, stream);
sb1, sb2, sb3, neqk1, rq3, scale, stream);
}
}
+4 -1
View File
@@ -76,7 +76,7 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
int row = tid / load_cols;
int col = tid % load_cols;
#pragma unroll
for (int idx = tid; idx < total_elems; idx += split_d_inner) {
for (int idx = 0; idx < total_elems; idx += split_d_inner) {
if (row < (int)split_d_inner) {
smem[row * n_cols + col] = x_block[row * stride_x + col];
}
@@ -84,6 +84,9 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
col += split_d_inner;
row += col / load_cols;
col = col % load_cols;
if (idx >= total_elems - tid - split_d_inner) {
break;
}
}
__syncthreads();
+4
View File
@@ -11,6 +11,10 @@ endif()
list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake")
if (NOT DEFINED CMAKE_HIP_FLAGS_DEBUG)
set(CMAKE_HIP_FLAGS_DEBUG "-g -O2")
endif()
# CMake on Windows doesn't support the HIP language yet
if (WIN32)
set(CXX_IS_HIPCC TRUE)
+55
View File
@@ -491,6 +491,61 @@ static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
// UE4M3: unsigned, 4 exp bits (bias=7), 3 mantissa bits
// Returns value * 0.5 to match kvalues_mxfp4 convention (kvalues = 2 * E2M1_float)
static inline float ggml_ue4m3_to_fp32(uint8_t x) {
if (x == 0 || x == 0x7F) {
return 0.0f;
}
int exp = (x >> 3) & 0xF;
int man = x & 0x7;
float raw;
if (exp == 0) {
raw = ldexpf((float) man, -9);
} else {
raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7);
}
return raw * 0.5f;
}
static inline uint8_t ggml_fp32_to_ue4m3(float x) {
if (!(x > 0.0f)) {
return 0;
}
if (x > 448.0f) {
x = 448.0f;
}
uint32_t bits;
memcpy(&bits, &x, 4);
int fp32_exp = ((bits >> 23) & 0xFF) - 127;
int fp32_man = (bits >> 20) & 0x7;
int ue4m3_exp = fp32_exp + 7;
if (ue4m3_exp <= 0) {
// subnormal: value = man * 2^-9, man = round(x * 2^9)
int man = (int) (x * 512.0f + 0.5f);
if (man > 7) {
man = 7;
}
if (man < 1) {
return 0;
}
return (uint8_t) man;
}
if (ue4m3_exp >= 15) {
return 0x7E;
}
int round_bit = (bits >> 19) & 1;
int ue4m3_man = fp32_man + round_bit;
if (ue4m3_man > 7) {
ue4m3_man = 0;
ue4m3_exp++;
if (ue4m3_exp >= 15) {
return 0x7E;
}
}
return (uint8_t) ((ue4m3_exp << 3) | ue4m3_man);
}
/**
* Converts brain16 to float32.
*
+25 -8
View File
@@ -47,7 +47,7 @@ struct ggml_metal {
uint64_t fuse_cnt[GGML_OP_COUNT];
// capture state
bool capture_next_compute;
int capture_compute;
bool capture_started;
id<MTLCaptureScope> capture_scope;
@@ -158,10 +158,17 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false");
GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false");
res->capture_next_compute = false;
res->capture_compute = 0;
res->capture_started = false;
res->capture_scope = nil;
{
const char * val = getenv("GGML_METAL_CAPTURE_COMPUTE");
if (val) {
res->capture_compute = atoi(val);
}
}
res->has_error = false;
res->gf = nil;
@@ -458,9 +465,13 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
const bool use_capture = ctx->capture_next_compute;
if (ctx->capture_compute >= 0) {
ctx->capture_compute--;
}
const bool use_capture = ctx->capture_compute == 0;
if (use_capture) {
ctx->capture_next_compute = false;
ctx->capture_compute = -1;
// make sure all previous computations have finished before starting the capture
if (ctx->cmd_buf_last) {
@@ -469,6 +480,10 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
}
if (!ctx->capture_started) {
NSString * path = [NSString stringWithFormat:@"/tmp/perf-metal-%d.gputrace", getpid()];
GGML_LOG_WARN("%s: capturing graph in %s\n", __func__, [path UTF8String]);
// create capture scope
id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device];
@@ -476,7 +491,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
descriptor.captureObject = ctx->capture_scope;
descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
descriptor.outputURL = [NSURL fileURLWithPath:path];
NSError * error = nil;
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
@@ -539,7 +554,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
// enter here only when capturing in order to wait for all computation to finish
// otherwise, we leave the graph to compute asynchronously
if (!use_capture && ctx->capture_started) {
if (use_capture && ctx->capture_started) {
// wait for completion and check status of each command buffer
// needed to detect if the device ran out-of-memory for example (#1881)
{
@@ -591,6 +606,8 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
[ctx->capture_scope endScope];
[[MTLCaptureManager sharedCaptureManager] stopCapture];
ctx->capture_started = false;
}
}
@@ -683,7 +700,7 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
idx_end,
ctx->use_fusion,
ctx->use_concurrency,
ctx->capture_next_compute,
ctx->capture_compute,
ctx->debug_graph,
ctx->debug_fusion);
@@ -718,5 +735,5 @@ bool ggml_metal_supports_family(ggml_metal_t ctx, int family) {
}
void ggml_metal_capture_next_compute(ggml_metal_t ctx) {
ctx->capture_next_compute = true;
ctx->capture_compute = 1;
}
+38 -1
View File
@@ -577,6 +577,41 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) {
char base[256];
char name[256];
// v is src[2], dimensions: S_v = ne[0], H = ne[1]
const int ne20 = op->src[2]->ne[0]; // S_v
const int ne21 = op->src[2]->ne[1]; // H
const int ne30 = op->src[3]->ne[0]; // G
const int nsg = op->src[2]->ne[0]/32;
GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(op->ne[0] == ne20 * ne21);
GGML_ASSERT(ne20 % 32 == 0);
snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg);
snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);
ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
res.nsg = nsg;
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
char base[256];
char name[256];
@@ -1435,10 +1470,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l
const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
const bool is_cb = op->src[0]->ne[0] != op->src[1]->ne[0];
const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);
snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d_cb=%d", base, op_num, n_fuse, is_rb, is_cb);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
@@ -1447,6 +1483,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l
ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2);
ggml_metal_cv_set_bool (cv, is_cb, FC_BIN + 3);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+1
View File
@@ -125,6 +125,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
+4 -2
View File
@@ -1155,10 +1155,12 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_GATED_DELTA_NET:
return op->src[2]->ne[0] % 32 == 0;
case GGML_OP_SOLVE_TRI:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return has_simdgroup_reduction;
return has_simdgroup_reduction && op->src[0]->type != GGML_TYPE_NVFP4;
case GGML_OP_SET:
case GGML_OP_CPY:
case GGML_OP_DUP:
@@ -1216,7 +1218,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
};
}
case GGML_OP_GET_ROWS:
return true;
return op->src[0]->type != GGML_TYPE_NVFP4;
case GGML_OP_SET_ROWS:
{
if (op->src[0]->type != GGML_TYPE_F32) {
+40 -1
View File
@@ -35,7 +35,7 @@
#define N_R0_Q4_K 2
#define N_SG_Q4_K 2
#define N_R0_Q5_K 2
#define N_R0_Q5_K 1
#define N_SG_Q5_K 2
#define N_R0_Q6_K 2
@@ -84,6 +84,7 @@
#define FC_BIN 1300
#define FC_SUM_ROWS 1400
#define FC_UPSCALE 1500
#define FC_GATED_DELTA_NET 1600
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
@@ -793,6 +794,44 @@ typedef struct {
uint64_t nb0;
} ggml_metal_kargs_ssm_scan;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne20;
int32_t ne21;
int32_t ne22;
int32_t ne23;
uint64_t nb20;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ns02;
int32_t ns12;
int32_t ns22;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_gated_delta_net;
typedef struct {
int32_t ne00;
int32_t ne01;
+80 -3
View File
@@ -333,6 +333,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_rwkv(ctx, idx);
} break;
case GGML_OP_GATED_DELTA_NET:
{
n_fuse = ggml_metal_op_gated_delta_net(ctx, idx);
} break;
case GGML_OP_SOLVE_TRI:
{
n_fuse = ggml_metal_op_solve_tri(ctx, idx);
@@ -1562,6 +1566,81 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op);
int ida = 0;
ggml_metal_kargs_gated_delta_net args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.ne13 =*/ ne13,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne20 =*/ ne20,
/*.ne21 =*/ ne21,
/*.ne22 =*/ ne22,
/*.ne23 =*/ ne23,
/*.nb20 =*/ nb20,
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.ns02 =*/ (int32_t) (nb02/sizeof(float)),
/*.ns12 =*/ (int32_t) (nb12/sizeof(float)),
/*.ns22 =*/ (int32_t) (nb22/sizeof(float)),
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst
const int nsg = pipeline.nsg;
ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1);
return 1;
}
int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@@ -3101,9 +3180,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
if (pipeline.cnt) {
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1);
} else {
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+1
View File
@@ -58,6 +58,7 @@ int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_gated_delta_net (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
+229 -4
View File
@@ -1111,6 +1111,7 @@ template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_un
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]];
template <typename T0, typename T1, typename T>
kernel void kernel_bin_fuse_impl(
@@ -1124,11 +1125,12 @@ kernel void kernel_bin_fuse_impl(
#define FC_OP FC_bin_op
#define FC_F FC_bin_f
#define FC_RB FC_bin_rb
#define FC_CB FC_bin_cb
if (FC_RB) {
// row broadcast
const uint i0 = tgpig.x;
const uint i1 = i0%args.ne10;
const uint i0 = tgpig.y*args.ne00 + tgpig.x;
const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
device const T0 * src0_row = (device const T0 *) (src0);
device T * dst_row = (device T *) (dst);
@@ -1200,7 +1202,7 @@ kernel void kernel_bin_fuse_impl(
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
const int i10 = FC_CB ? i0%args.ne10 : i0;
if (FC_OP == 0) {
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
@@ -1225,7 +1227,7 @@ kernel void kernel_bin_fuse_impl(
}
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
const int i10 = FC_CB ? i0%args.ne10 : i0;
T res = src0_ptr[i0];
@@ -1261,6 +1263,7 @@ kernel void kernel_bin_fuse_impl(
#undef FC_OP
#undef FC_F
#undef FC_RB
#undef FC_CB
}
typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
@@ -2434,6 +2437,227 @@ kernel void kernel_rwkv_wkv7_f32(
}
}
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
#if 1
template<short NSG>
kernel void kernel_gated_delta_net_impl(
constant ggml_metal_kargs_gated_delta_net & args,
device const char * q,
device const char * k,
device const char * v,
device const char * g,
device const char * b,
device const char * s,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G FC_gated_delta_net_ne30
const uint tx = tpitg.x;
const uint ty = tpitg.y;
const uint i23 = tgpig.z; // B
const uint i21 = tgpig.y; // H
const uint i20 = tgpig.x*NSG + ty;
const uint i01 = i21 % args.ne01;
const uint i11 = i21 % args.ne11;
const float scale = 1.0f / sqrt((float)S_v);
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
float ls[NSG];
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] = s_ptr[is*S_v];
}
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
for (short t = 0; t < args.ne22; t++) {
float s_k = 0.0f;
if (G == 1) {
const float g_exp = exp(g_ptr[0]);
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] *= g_exp;
s_k += ls[j]*k_ptr[is];
}
} else {
// KDA
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] *= exp(g_ptr[is]);
s_k += ls[j]*k_ptr[is];
}
}
s_k = simd_sum(s_k);
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
float y = 0.0f;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] += k_ptr[is]*d;
y += ls[j]*q_ptr[is];
}
y = simd_sum(y);
if (tx == 0) {
dst_attn[t*args.ne21*S_v] = y*scale;
}
q_ptr += args.ns02;
k_ptr += args.ns12;
v_ptr += args.ns22;
b_ptr += args.ne21;
g_ptr += args.ne21*G;
}
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is*S_v] = ls[j];
}
#undef S_v
#undef G
}
typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
#else
// a simplified version of the above
// no performance improvement, so keep the above version for now
template<typename T, short NSG>
kernel void kernel_gated_delta_net_impl(
constant ggml_metal_kargs_gated_delta_net & args,
device const char * q,
device const char * k,
device const char * v,
device const char * g,
device const char * b,
device const char * s,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G FC_gated_delta_net_ne30
const uint tx = tpitg.x;
const uint ty = tpitg.y;
const uint i23 = tgpig.z; // B
const uint i21 = tgpig.y; // H
const uint i20 = tgpig.x*NSG + ty;
const uint i01 = i21 % args.ne01;
const uint i11 = i21 % args.ne11;
const float scale = 1.0f / sqrt((float)S_v);
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
float lsf[NSG];
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
lsf[j] = s_ptr[is*S_v];
}
thread T * ls = (thread T *) (lsf);
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
for (short t = 0; t < args.ne22; t++) {
device const T * qt_ptr = (device const T *) (q_ptr);
device const T * kt_ptr = (device const T *) (k_ptr);
device const T * gt_ptr = (device const T *) (g_ptr);
if (G == 1) {
*ls *= exp(g_ptr[0]);
} else {
// KDA
*ls *= exp(gt_ptr[tx]);
}
const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
*ls += kt_ptr[tx]*d;
const float y = simd_sum(dot(*ls, qt_ptr[tx]));
if (tx == 0) {
*dst_attn = y*scale;
}
q_ptr += args.ns02;
k_ptr += args.ns12;
v_ptr += args.ns22;
b_ptr += args.ne21;
g_ptr += args.ne21*G;
dst_attn += args.ne21*S_v;
}
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
device T * dstt_state = (device T *) (dst_state);
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is*S_v] = lsf[j];
}
#undef S_v
#undef G
}
typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float, 1>;
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
#endif
constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]];
constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];
@@ -9081,6 +9305,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(
+1
View File
@@ -132,6 +132,7 @@ set(GGML_OPENCL_KERNELS
ssm_conv
sub
sum_rows
cumsum
transpose
concat
tsembd
+153 -15
View File
@@ -547,6 +547,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_im2col_f32, kernel_im2col_f16;
cl_kernel kernel_argsort_f32_i32;
cl_kernel kernel_sum_rows_f32, kernel_sum_rows_f32_4;
cl_kernel kernel_cumsum_blk, kernel_cumsum_add;
cl_kernel kernel_repeat_f32;
cl_kernel kernel_pad;
cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc;
@@ -1927,6 +1928,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// cumsum
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "cumsum.cl.h"
};
#else
const std::string kernel_src = read_file("cumsum.cl");
#endif
cl_program prog;
prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_cumsum_blk = clCreateKernel(prog, "kernel_cumsum_blk", &err), err));
CL_CHECK((backend_ctx->kernel_cumsum_add = clCreateKernel(prog, "kernel_cumsum_add", &err), err));
GGML_LOG_CONT(".");
CL_CHECK(clReleaseProgram(prog));
}
// sigmoid
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3803,6 +3822,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32;
}
case GGML_OP_SUM_ROWS:
case GGML_OP_CUMSUM:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
case GGML_OP_MEAN:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_FLASH_ATTN_EXT:
@@ -5775,19 +5796,12 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
const int ne00 = src0->ne[0];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const int ne10 = src1->ne[0];
const cl_ulong nb10 = src1->nb[0];
const int ne11 = src1->ne[1];
const int ne12 = src1->ne[2];
const cl_ulong nb11 = src1->nb[1];
const cl_ulong nb12 = src1->nb[2];
const cl_ulong nb1 = dst->nb[1];
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
GGML_TENSOR_LOCALS(int, ne0, src0, ne);
GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);
GGML_TENSOR_LOCALS(int, ne1, src1, ne);
GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb);
GGML_TENSOR_LOCALS(int, ne, dst, ne);
GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -5833,8 +5847,14 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3));
size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12};
size_t local_work_size[] = {64, 1, 1};
int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
int nth = 1;
while (nth < ne00 && 2*nth <= max_workgroup_size) {
nth *= 2;
}
size_t global_work_size[] = {(size_t)ne10*nth, (size_t)ne11, (size_t)ne12};
size_t local_work_size[] = {(size_t)nth, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
@@ -11949,6 +11969,118 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
static void ggml_cl_cumsum(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
GGML_UNUSED(src1);
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
GGML_ASSERT(ggml_is_contiguous(src0));
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
GGML_TENSOR_LOCALS(int, ne0, src0, ne);
GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);
cl_kernel kernel = backend_ctx->kernel_cumsum_blk;
int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
int nth = 1;
while (nth < ne00 && 2*nth <= max_workgroup_size) {
nth *= 2;
}
GGML_ASSERT(ne00 <= nth*nth);
const int net0 = CEIL_DIV(ne00, nth);
const int net1 = ne01;
const int net2 = ne02;
const int net3 = ne03;
const cl_ulong nbt0 = sizeof(float);
const cl_ulong nbt1 = net0*nbt0;
const cl_ulong nbt2 = net1*nbt1;
const cl_ulong nbt3 = net2*nbt2;
static ggml_cl_buffer tmp_buffer;
tmp_buffer.allocate(backend_ctx->context, net0*ne01*ne02*ne03*sizeof(float));
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &tmp_buffer.buffer));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &net0));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &net1));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &net2));
size_t global_work_size[] = { (size_t)(nth*net0*ne01), (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = { (size_t)nth, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
if(ne00 > nth) {
// if a single workgroup cannot handle an entire row, each workgroup
// computes a partial sum and stores to dst, tmp_buffer contains the sum
// of the each workgroup; cumsum this buffer and add to the partial sums in dst
cl_ulong offsett = 0;
kernel = backend_ctx->kernel_cumsum_blk;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &tmp_buffer.buffer));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offsett));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &tmp_buffer.buffer));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &tmp_buffer.buffer));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsett));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &net0));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nbt0));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nbt1));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nbt2));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nbt3));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &net0));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &net1));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &net2));
size_t global_work_size_1[] = { (size_t)net1*nth, (size_t)net2, (size_t)net3};
size_t local_work_size_1[] = { (size_t)nth, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_1, local_work_size_1, dst);
kernel = backend_ctx->kernel_cumsum_add;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &tmp_buffer.buffer));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &nbt0));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &nbt1));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &nbt2));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &nbt3));
size_t global_work_size_2[] = { (size_t)(nth*net0*ne01), (size_t)ne02, (size_t)ne03};
size_t local_work_size_2[] = { (size_t)nth, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_2, local_work_size_2, dst);
}
}
static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@@ -12391,6 +12523,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_sum_rows;
break;
case GGML_OP_CUMSUM:
if (!any_on_device) {
return false;
}
func = ggml_cl_cumsum;
break;
case GGML_OP_FLASH_ATTN_EXT:
if (!any_on_device) {
return false;
+139
View File
@@ -0,0 +1,139 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
// max workgroup size is usually 1024, this covers various subgroups sizes
#define MAX_SUBGROUPS 128
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_32
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_cumsum_blk(
global char * src0,
ulong offset0,
global char * tmp,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
uint net0,
uint net1,
uint net2
) {
src0 = src0 + offset0;
dst = dst + offsetd;
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int nth = get_local_size(0);
const int tid = get_local_id(0);
const uint sg_size = get_sub_group_size();
const uint sg_id = get_sub_group_id();
const uint sg_lid = get_sub_group_local_id();
const int ib = i1 / ne01;
const int i00 = ib * nth;
const int i01 = i1 % ne01;
const int i02 = i2;
const int i03 = i3;
global const float * src0_row = (global const float *)(src0 + i03*nb03 + i02*nb02 + i01*nb01);
global float * tmp_row = (global float *)tmp + net0 * i01 + net0 * net1 * i02 + net0 * net1 * net2 * i03;
global float * dst_row = (global float *)dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
__local float partial[MAX_SUBGROUPS];
float v = 0.0f;
if (i00 + tid < ne00) {
v = src0_row[i00 + tid];
}
float s = sub_group_scan_inclusive_add(v);
if (sg_lid == sg_size - 1) {
partial[sg_id] = s;
}
barrier(CLK_LOCAL_MEM_FENCE);
// NB: subgroup size should be larger than number of subgroups
// assuming max workgroup size of 1024, subgroup size should be >= 32
if (sg_id == 0) {
float x = 0.0f;
if (sg_lid < get_num_sub_groups()) {
x = partial[sg_lid];
}
float ex = sub_group_scan_exclusive_add(x);
if (sg_lid < get_num_sub_groups()) {
partial[sg_lid] = ex;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
s += partial[sg_id];
if (i00 + tid < ne00) {
dst_row[i00 + tid] = s;
}
if (ne00 > nth && tid == nth - 1) {
tmp_row[ib] = s;
}
}
kernel void kernel_cumsum_add(
global char * tmp,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
uint nbt0,
uint nbt1,
uint nbt2,
uint nbt3
) {
dst = dst + offsetd;
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
const int nth = get_local_size(0);
const int tid = get_local_id(0);
const int ib = i1 / ne01;
if (ib == 0) {
return;
}
const int i00 = ib * nth;
const int i01 = i1 % ne01;
const int i02 = i2;
const int i03 = i3;
global float * tmp_row = (global float *)(tmp + nbt1 * i01 + nbt2 * i02 + nbt3 * i03);
global float * dst_row = (global float *)dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
if (i00 + tid < ne00) {
dst_row[i00 + tid] += tmp_row[ib - 1];
}
}
+72
View File
@@ -304,6 +304,41 @@ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RE
}
}
void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k) {
static const int qk = QK_NVFP4;
static const int qk_sub = QK_NVFP4_SUB;
static const int n_sub = QK_NVFP4 / QK_NVFP4_SUB;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
for (int s = 0; s < n_sub; s++) {
const float * xb = x + i*qk + s*qk_sub;
float amax = 0.0f;
for (int j = 0; j < qk_sub; j++) {
if (amax < fabsf(xb[j])) {
amax = fabsf(xb[j]);
}
}
// UE4M3 scale: amax / 6.0 maps the max E2M1 value (6.0) to amax
const uint8_t ue = ggml_fp32_to_ue4m3(amax / 6.0f);
y[i].d[s] = ue;
const float d = ggml_ue4m3_to_fp32(ue);
for (int j = 0; j < qk_sub/2; ++j) {
const uint8_t x0 = best_index_mxfp4(xb[0 + j], d);
const uint8_t x1 = best_index_mxfp4(xb[qk_sub/2 + j], d);
y[i].qs[s*(qk_sub/2) + j] = x0 | (x1 << 4);
}
}
}
}
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
static const int qk = QK4_0;
@@ -434,6 +469,31 @@ void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_REST
}
}
void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
static const int qk = QK_NVFP4;
static const int qk_sub = QK_NVFP4_SUB;
static const int n_sub = QK_NVFP4 / QK_NVFP4_SUB;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
for (int s = 0; s < n_sub; s++) {
const float d = ggml_ue4m3_to_fp32(x[i].d[s]);
float * yb = y + i*qk + s*qk_sub;
for (int j = 0; j < qk_sub/2; ++j) {
const int8_t v0 = kvalues_mxfp4[x[i].qs[s*(qk_sub/2) + j] & 0x0F];
const int8_t v1 = kvalues_mxfp4[x[i].qs[s*(qk_sub/2) + j] >> 4];
yb[j + 0 ] = v0*d;
yb[j + qk_sub/2] = v1*d;
}
}
}
}
//
// 2-6 bit quantization in super-blocks
//
@@ -2098,6 +2158,12 @@ size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
}
size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
GGML_UNUSED(quant_weights);
quantize_row_nvfp4_ref(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row);
}
// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
@@ -5244,6 +5310,12 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
} break;
case GGML_TYPE_NVFP4:
{
// UE4M3 scales are uint8_t — all byte values are valid
GGML_UNUSED(data);
GGML_UNUSED(nb);
} break;
case GGML_TYPE_Q2_K:
{
VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
+3
View File
@@ -22,6 +22,7 @@ GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 *
GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
@@ -48,6 +49,7 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG
//GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -95,6 +97,7 @@ GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTR
GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API void iq2xs_init_impl(enum ggml_type type);
GGML_API void iq2xs_free_impl(enum ggml_type type);
+91
View File
@@ -874,4 +874,95 @@ static bool fast_fp16_available(const int cc) {
return true; //Intel GPUs always support FP16.
}
enum class block_reduce_method {
MAX,
SUM,
};
template<block_reduce_method method_t, typename T, int warp_size>
struct block_reduce_policy;
template <typename T, typename... Ts>
inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);
template<typename...>
inline constexpr bool ggml_sycl_dependent_false_v = false;
#define WARP_32_SIZE 32
template <typename T, int warp_size> struct block_reduce_policy<block_reduce_method::SUM, T, warp_size> {
static T reduce(T val) {
if constexpr (is_any<T, float, sycl::float2, sycl::half2, int>) {
return warp_reduce_sum<warp_size>(val);
} else {
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce sum");
}
}
static T sentinel() {
if constexpr (std::is_same_v<T, float>) {
return 0.0f;
} else if constexpr (std::is_same_v<T, sycl::float2>) {
return sycl::float2(0.0f, 0.0f);
} else if constexpr (std::is_same_v<T, sycl::half2>) {
return sycl::half2(0.0f, 0.0f);
} else if constexpr (std::is_same_v<T, int>) {
return 0;
} else {
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce sum");
}
}
};
template <typename T, int warp_size> struct block_reduce_policy<block_reduce_method::MAX, T, warp_size> {
static T reduce(T val) {
if constexpr (is_any<T, float, sycl::half2>) {
return warp_reduce_max<warp_size>(val);
} else {
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce max");
}
}
static T sentinel() {
if constexpr (std::is_same_v<T, float>) {
return -INFINITY;
} else if constexpr (std::is_same_v<T, sycl::half2>) {
return sycl::half2(-INFINITY, -INFINITY);
} else {
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce max");
}
}
};
template <block_reduce_method reduce_method_t, int warp_size, typename T>
static T block_reduce(T val, T * shared_vals, int block_size_template) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
val = block_reduce_policy<reduce_method_t, T,warp_size>::reduce(val);
const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE;
if (block_size > warp_size) {
assert((block_size <= 1024) && (block_size % warp_size) == 0);
const int warp_id = item_ct1.get_local_id(2) / warp_size;
const int lane_id = item_ct1.get_local_id(2) % warp_size;
if (lane_id == 0) {
shared_vals[warp_id] = val;
}
item_ct1.barrier(sycl::access::fence_space::local_space);
size_t nreduce = nwarps / WARP_SIZE;
float tmp = 0.f;
if (lane_id < (static_cast<int>(block_size) / warp_size)) {
for (size_t i = 0; i < nreduce; i += 1)
{
tmp += shared_vals[lane_id + i * WARP_SIZE];
}
}
return block_reduce_policy<reduce_method_t, T, warp_size>::reduce(tmp);
}
return val;
}
#endif // GGML_SYCL_COMMON_HPP
+6
View File
@@ -39,6 +39,11 @@ template<typename dst_t, typename src_t>
return sycl::ext::oneapi::bfloat16(float(x));
} else if constexpr (std::is_same_v<src_t, sycl::ext::oneapi::bfloat16>) {
return static_cast<float>(x);
} else if constexpr (std::is_same_v<src_t, sycl::float2> && std::is_same_v<dst_t, sycl::half2>) {
return x.template convert<sycl::half, sycl::rounding_mode::rte>();
} else if constexpr (std::is_same_v<src_t, sycl::float2> &&
std::is_same_v<dst_t, sycl::vec<sycl::ext::oneapi::bfloat16, 2>>) {
return {x.x, x.y};
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
return int32_t(x);
} else {
@@ -46,4 +51,5 @@ template<typename dst_t, typename src_t>
}
}
#endif // GGML_SYCL_CONVERT_HPP
+55 -58
View File
@@ -9,23 +9,32 @@
#define SYCL_LOCAL_ID_CALC(ITEM, IDX) \
(ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX))
static void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int64_t i = SYCL_LOCAL_ID_CALC(item_ct1, 2);
static void acc_f32(const float * x, const float * y, float * dst, const int ne,
const int ne10, const int ne11, const int ne12,
const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) {
const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0);
if (i >= ne) {
return;
}
int src1_idx = i - offset;
int oz = src1_idx / nb2;
int oy = (src1_idx - (oz * nb2)) / nb1;
int ox = src1_idx % nb1;
if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
} else {
dst[i] = x[i];
int64_t src1_idx = i - offset;
int64_t tmp = src1_idx;
const int64_t i13 = tmp / s13;
tmp -= i13 * s13;
const int64_t i12 = tmp / s12;
tmp -= i12 * s12;
const int64_t i11 = tmp / s11;
tmp -= i11 * s11;
const int64_t i10 = tmp;
float val = x[i];
if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
}
dst[i] = val;
}
/* Unary OP funcs */
@@ -364,18 +373,15 @@ static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const
namespace ggml_sycl_detail {
static void acc_f32_sycl(const float *x, const float *y, float *dst,
const int n_elements, const int ne10, const int ne11,
const int ne12, const int nb1, const int nb2,
const int offset, queue_ptr stream) {
int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) *
sycl::range<1>(SYCL_ACC_BLOCK_SIZE),
sycl::range<1>(SYCL_ACC_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
item_ct1);
});
const int64_t n_elements, const int64_t ne10, const int64_t ne11,
const int64_t ne12, const int64_t ne13, const int64_t s1, const int64_t s2, const int64_t s3,
const int64_t offset, queue_ptr stream) {
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
});
}
template<typename T>
@@ -402,25 +408,19 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
GGML_ASSERT(dst->src[0]->type == dst->type);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
switch (dst->type) {
#if defined (GGML_SYCL_F16)
case GGML_TYPE_F16:
{
auto data_pts = cast_data<sycl::half>(dst);
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
break;
}
#endif
case GGML_TYPE_F32:
{
auto data_pts = cast_data<float>(dst);
@@ -434,14 +434,10 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx,
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
GGML_ASSERT(dst->src[0]->type == dst->type);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const ggml_tensor * src0 = dst->src[0];
@@ -463,7 +459,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
GGML_ASSERT(src0->type == src1->type);
}
switch (dst->type) {
#if defined (GGML_SYCL_F16)
case GGML_TYPE_F16:
{
sycl::half * src0_p = (sycl::half *) src0_d;
@@ -484,7 +479,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
std::forward<Args>(args)...);
break;
}
#endif
case GGML_TYPE_F32:
{
float * src0_p = (float *) src0_d;
@@ -513,13 +507,9 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
GGML_ASSERT(dst->src[0]->type == dst->type);
dpct::queue_ptr main_stream = ctx.stream();
@@ -530,7 +520,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
switch (dst->type) {
#if defined (GGML_SYCL_F16)
case GGML_TYPE_F16:
{
auto data_pts = cast_data<sycl::half>(dst);
@@ -539,7 +528,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
main_stream, std::forward<Args>(args)...);
break;
}
#endif
case GGML_TYPE_F32:
{
auto data_pts = cast_data<float>(dst);
@@ -868,22 +856,31 @@ static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tens
}
static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
const float * src1_dd = static_cast<const float*>(dst->src[1]->data);
float * dst_dd = static_cast<float *>(dst->data);
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
GGML_ASSERT(ggml_is_contiguously_allocated(dst));
ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
const int64_t s1 = dst->op_params[0] / sizeof(float);
const int64_t s2 = dst->op_params[1] / sizeof(float);
const int64_t s3 = dst->op_params[2] / sizeof(float);
const int64_t offset = dst->op_params[3] / sizeof(float);
ggml_sycl_detail::acc_f32_sycl(src0_d, src1_d, dst_d, ggml_nelements(dst),
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
s1, s2, s3, offset, stream);
}
static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+6 -1
View File
@@ -4145,6 +4145,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_ROPE:
ggml_sycl_rope(ctx, dst);
break;
case GGML_OP_ROPE_BACK:
ggml_sycl_rope_back(ctx, dst);
break;
case GGML_OP_IM2COL:
ggml_sycl_im2col(ctx, dst);
break;
@@ -4851,6 +4854,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return max_bias == 0.0f;
}
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
case GGML_OP_IM2COL:
return true;
case GGML_OP_UPSCALE:
@@ -4872,8 +4876,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
k > 0 && k <= 32;
}
case GGML_OP_POOL_2D:
case GGML_OP_ACC:
return true;
case GGML_OP_ACC:
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
case GGML_OP_PAD:
// TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985
if (ggml_get_op_params_i32(op, 8) != 0) {
+65 -63
View File
@@ -202,47 +202,34 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
}
}
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);
const int tid = item_ct1.get_local_id(2);
const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE;
template<int warp_size>
static void l2_norm_f32(const float * x, float * dst, const int ncols,
const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps,
const sycl::nd_item<3>& item_ct1, float* s_sum, const int block_size) {
const int nrows = item_ct1.get_group_range(2);
const int nchannels = item_ct1.get_group_range(1);
const int row = item_ct1.get_group(2);
const int channel = item_ct1.get_group(1);
const int sample = item_ct1.get_group(0);
const int tid = item_ct1.get_local_id(2);
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row * ncols + col];
const float xi = x[col];
tmp += xi * xi;
}
// sum up partial sums
tmp = warp_reduce_sum(tmp, item_ct1);
if (block_size > WARP_SIZE) {
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
/*
DPCT1118:3: SYCL group functions and algorithms must be encountered in
converged control flow. You may need to adjust the code.
*/
item_ct1.barrier(sycl::access::fence_space::local_space);
size_t nreduce = nwarps / WARP_SIZE;
tmp = 0.f;
for (size_t i = 0; i < nreduce; i += 1)
{
tmp += s_sum[lane_id + i * WARP_SIZE];
}
tmp = warp_reduce_sum(tmp, item_ct1);
}
const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
tmp = block_reduce<block_reduce_method::SUM, warp_size>(tmp, s_sum, block_size);
const float scale = sycl::rsqrt(sycl::fmax(tmp, eps * eps));
for (int col = tid; col < ncols; col += block_size) {
dst[row * ncols + col] = scale * x[row * ncols + col];
dst[col] = scale * x[col];
}
}
@@ -369,42 +356,50 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
}
}
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps,
queue_ptr stream, int device) {
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
template<int warp_size>
static void l2_norm_f32_sycl(const float * x,
float * dst,
const int ncols,
const int nrows,
const int nchannels,
const int nsamples,
const int64_t stride_row,
const int64_t stride_channel,
const int64_t stride_sample,
const float eps,
queue_ptr stream,
int device) {
const dpct::dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
const dpct::dim3 block_dims(warp_size, 1, 1);
stream->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
sycl::nd_range<3>(blocks_num * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE);
[[sycl::reqd_sub_group_size(warp_size)]] {
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
nullptr, warp_size);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
assert(work_group_size % (warp_size * warp_size) == 0);
const sycl::range<3> block_dims(1, 1, work_group_size);
/*
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
int lsm_size = block_dims[2] > warp_size ? work_group_size / warp_size * sizeof(float): 0;
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(lsm_size),
cgh);
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
sycl::nd_range<3>(blocks_num * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
[[sycl::reqd_sub_group_size(warp_size)]] {
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample,
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
@@ -634,21 +629,28 @@ void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * d
}
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
const int64_t ne00 = dst->src[0]->ne[0];
const int64_t nrows = ggml_nrows(dst->src[0]);
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
float * dst_dd = static_cast<float *>(dst->data);
GGML_TENSOR_UNARY_OP_LOCALS;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
/*support both WARP_SIZE or WARP_32_SIZE in code
choose by hardware for better performance
*/
l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device);
}
+447 -283
View File
@@ -1,4 +1,5 @@
#include "rope.hpp"
#include "convert.hpp"
#include "ggml-sycl/common.hpp"
#include "ggml.h"
@@ -15,366 +16,489 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
float * cos_theta, float * sin_theta) {
// Get n-d rotational scaling corrected for extrapolation
template <bool forward>
static void rope_yarn(const float theta_extrap, const float freq_scale,
const rope_corr_dims corr_dims, const int64_t i0,
const float ext_factor, float mscale, float &cos_theta,
float &sin_theta) {
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
float ramp_mix =
rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
}
*cos_theta = sycl::cos(theta) * mscale;
*sin_theta = sycl::sin(theta) * mscale;
cos_theta = sycl::cos(theta) * mscale;
sin_theta = sycl::sin(theta) * mscale;
if (!forward) {
sin_theta *= -1.0f;
}
}
template <typename T, bool has_ff>
static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
const sycl::nd_item<3> & item_ct1) {
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
template <bool forward, bool has_ff, typename T, typename D>
static void rope_norm(const T *x, D *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02,
const int s03, const int s1, const int s2, const int s3,
const int n_dims, const int32_t *pos,
const float freq_scale, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float *freq_factors,
const int64_t *row_indices, const int set_rows_stride) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const int row0 = row % ne1;
const int channel0 = row / ne1;
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
const int i = row * ne0 + i0;
const int i2 = channel0 * s2 + row0 * s1 + i0;
int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
if (set_rows_stride != 0) {
idst = i1 * s1 + i0;
idst += row_indices[i2] * set_rows_stride;
}
const auto &store_coaelsced = [&](float x0, float x1) {
if constexpr (std::is_same_v<float, D>) {
sycl::float2 v = sycl::float2(x0, x1);
ggml_sycl_memcpy_1<8>(dst + idst, &v);
} else if constexpr (std::is_same_v<sycl::half, D>) {
sycl::half2 v = sycl::half2(x0, x1);
ggml_sycl_memcpy_1<4>(dst + idst, &v);
}
};
if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
store_coaelsced(x[ix + 0], x[ix + 1]);
return;
}
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[i2 + 0];
const float x1 = x[i2 + 1];
const float x0 = x[ix + 0];
const float x1 = x[ix + 1];
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
store_coaelsced(x0 * cos_theta - x1 * sin_theta,
x0 * sin_theta + x1 * cos_theta);
}
template <typename T, bool has_ff>
static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
const sycl::nd_item<3> & item_ct1) {
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
template <bool forward, bool has_ff, typename T, typename D>
static void rope_neox(const T *x, D *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02,
const int s03, const int s1, const int s2, const int s3,
const int n_dims, const int32_t *pos,
const float freq_scale, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float *freq_factors,
const int64_t *row_indices, const int set_rows_stride) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const int row0 = row % ne1;
const int channel0 = row / ne1;
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
const int i = row * ne0 + i0 / 2;
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
if (set_rows_stride != 0) {
idst = i1 * s1 + i0 / 2;
idst += row_indices[i2] * set_rows_stride;
}
if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
dst[idst + i0 / 2 + 0] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 0]);
dst[idst + i0 / 2 + 1] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 1]);
return;
}
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[i2 + 0];
const float x1 = x[i2 + n_dims / 2];
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims / 2];
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
dst[idst + 0] = ggml_sycl_cast<D>(x0 * cos_theta - x1 * sin_theta);
dst[idst + n_dims / 2] = ggml_sycl_cast<D>(x0 * sin_theta + x1 * cos_theta);
}
template <typename T, bool has_ff>
static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections,
const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
// get index pos
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
if (i0 >= ne0) {
template <bool forward, bool has_ff, typename T>
static void rope_multi(const T *x, T *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02,
const int s03, const int s1, const int s2, const int s3,
const int n_dims, const int32_t *pos,
const float freq_scale, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float *freq_factors,
const mrope_sections sections, const bool is_imrope) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
if (i0 >= ne00) {
return;
}
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = (row_dst * ne0) + (i0 / 2);
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
if (i0 >= n_dims) {
*reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
dst[idst + i0 / 2 + 0] = x[ix + i0 / 2 + 0];
dst[idst + i0 / 2 + 1] = x[ix + i0 / 2 + 1];
return;
}
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sect_dims =
sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
} else {
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
}
} else {
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
} else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
} else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
} else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
}
}
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims/2];
// store results in dst
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
float cos_theta;
float sin_theta;
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims / 2];
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
dst[idst + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
}
template <bool forward, bool has_ff, typename T>
static void rope_vision(const T *x, T *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02,
const int s03, const int s1, const int s2, const int s3,
const int n_dims, const int32_t *pos,
const float freq_scale, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float *freq_factors,
const mrope_sections sections) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
template <typename T, bool has_ff>
static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections,
const sycl::nd_item<3> & item_ct1) {
// get index pos
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = (row_dst * ne0) + (i0 / 2);
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
const int sect_dims = sections.v[0] + sections.v[1];
const int sector = (i0 / 2) % sect_dims;
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0f;
float theta_base = 0.0;
if (sector < sections.v[0]) {
const int p = sector;
theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p);
} else {
theta_base = pos[i2] * dpct::pow(theta_scale, p);
} else if (sector >= sections.v[0] && sector < sec_w) {
const int p = sector - sections.v[0];
theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p);
theta_base = pos[i2 + ne02] * dpct::pow(theta_scale, p);
}
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
float cos_theta;
float sin_theta;
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
ext_factor, attn_factor, cos_theta, sin_theta);
const float x0 = x[ix + 0];
const float x1 = x[ix + n_dims];
// store results in dst
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
}
template <typename T>
static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float * freq_factors, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
const sycl::range<3> block_nums(1, num_blocks_x, nr);
template <bool forward, typename T, typename D>
static void
rope_norm_sycl(const T *x, D *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02, const int s03,
const int s1, const int s2, const int s3, const int n_dims,
const int nr, const int32_t *pos, const float freq_scale,
const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float *freq_factors, const int64_t *row_indices,
const int set_rows_stride, dpct::queue_ptr stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x =
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
if (freq_factors == nullptr) {
/*
DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_norm<forward, false>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, row_indices, set_rows_stride);
});
} else {
/*
DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_norm<forward, true>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, row_indices, set_rows_stride);
});
}
}
template <typename T>
static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
const sycl::range<3> block_nums(1, num_blocks_x, nr);
template <bool forward, typename T, typename D>
static void
rope_neox_sycl(const T *x, D *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02, const int s03,
const int s1, const int s2, const int s3, const int n_dims,
const int nr, const int32_t *pos, const float freq_scale,
const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float *freq_factors, const int64_t *row_indices,
const int set_rows_stride, dpct::queue_ptr stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x =
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
if (freq_factors == nullptr) {
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_neox<forward, false>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, row_indices, set_rows_stride);
});
} else {
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_neox<forward, true>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, row_indices, set_rows_stride);
});
}
}
template <typename T>
static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
const float freq_scale, const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
template <bool forward, typename T>
static void
rope_multi_sycl(const T *x, T *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02, const int s03,
const int s1, const int s2, const int s3, const int n_dims,
const int nr, const int32_t *pos, const float freq_scale,
const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float *freq_factors, const mrope_sections sections,
const bool is_imrope, dpct::queue_ptr stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x =
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
// Add FP16 capability check if T could be sycl::half
if constexpr (std::is_same_v<T, sycl::half>) {
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
}
// launch kernel
if (freq_factors == nullptr) {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_multi<forward, false, T>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections, is_imrope);
});
} else {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_multi<forward, true, T>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections, is_imrope);
});
}
}
template <bool forward, typename T>
static void
rope_vision_sycl(const T *x, T *dst, const int ne00, const int ne01,
const int ne02, const int s01, const int s02, const int s03,
const int s1, const int s2, const int s3, const int n_dims,
const int nr, const int32_t *pos, const float freq_scale,
const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims,
const float *freq_factors, const mrope_sections sections,
dpct::queue_ptr stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x =
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
// rope vision
template <typename T>
static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
const float freq_scale, const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
const mrope_sections sections, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
// Add FP16 capability check if T could be sycl::half
if constexpr (std::is_same_v<T, sycl::half>) {
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
}
// launch kernel
if (freq_factors == nullptr) {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_vision<forward, false, T>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections);
});
} else {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1);
});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
GGML_UNUSED(item_ct1);
rope_vision<forward, true, T>(
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections);
});
}
}
inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
template <bool forward>
void ggml_sycl_op_rope_impl(ggml_backend_sycl_context &ctx, ggml_tensor *dst,
const ggml_tensor *set_rows = nullptr) {
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
const ggml_tensor *src2 = dst->src[2];
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(dst->src[0]->type == dst->type);
const int64_t ne00 = dst->src[0]->ne[0]; // head dims
const int64_t ne01 = dst->src[0]->ne[1]; // num heads
const int64_t ne02 = dst->src[0]->ne[2]; // num heads
const int64_t nr = ggml_nrows(dst->src[0]);
const float *src0_d = (const float *)src0->data;
const float *src1_d = (const float *)src1->data;
const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type);
const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type);
void *dst_d = dst->data;
const int64_t *row_indices = nullptr;
ggml_type dst_type = dst->type;
int set_rows_stride = 0;
if (set_rows != nullptr) {
GGML_ASSERT(forward);
dst_d = set_rows->data;
row_indices = (const int64_t *)set_rows->src[1]->data;
dst_type = set_rows->type;
set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
}
dpct::queue_ptr stream = ctx.stream();
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
//const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type ||
(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
const int64_t ne00 = src0->ne[0]; // head dims
const int64_t ne01 = src0->ne[1]; // num heads
const int64_t ne02 = src0->ne[2]; // num heads
const int64_t nr = ggml_nrows(src0);
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
const int n_dims = ((int32_t *)dst->op_params)[1];
const int mode = ((int32_t *)dst->op_params)[2];
const int n_ctx_orig = ((int32_t *)dst->op_params)[4];
mrope_sections sections;
// RoPE alteration for extended context
float freq_base;
float freq_scale;
float ext_factor;
@@ -382,13 +506,13 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
float beta_fast;
float beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
memcpy(&freq_base, (int32_t *)dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *)dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *)dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *)dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *)dst->op_params + 10, sizeof(float));
memcpy(&sections.v, (int32_t *)dst->op_params + 11, sizeof(int) * 4);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
@@ -396,82 +520,122 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 ||
sections.v[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne00/2);
GGML_ASSERT(n_dims == ne00 / 2);
}
const int32_t * pos = (const int32_t *) dst->src[1]->data;
const int32_t *pos = (const int32_t *)src1_d;
const float * freq_factors = nullptr;
if (dst->src[2] != nullptr) {
freq_factors = (const float *) dst->src[2]->data;
const float *freq_factors = nullptr;
if (src2 != nullptr) {
freq_factors = (const float *)src2->data;
}
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
beta_slow, corr_dims.v);
// compute
if (is_neox) {
GGML_SYCL_DEBUG("%s: neox path\n", __func__);
if (dst->src[0]->type == GGML_TYPE_F32) {
rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
main_stream);
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_neox_sycl<forward, float, float>(
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_neox_sycl<forward, float, sycl::half>(
(const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_neox_sycl<forward, sycl::half, sycl::half>(
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
row_indices, set_rows_stride, stream);
} else {
GGML_ABORT("fatal error");
GGML_ABORT("Fatal error: Tensor type unsupported!");
}
} else if (is_mrope && !is_vision) {
GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
if (dst->src[0]->type == GGML_TYPE_F16) {
rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, sections, is_imrope, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F32) {
rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
is_imrope, main_stream);
if (src0->type == GGML_TYPE_F32) {
rope_multi_sycl<forward>((const float *)src0_d, (float *)dst_d,
ne00, ne01, ne02, s01, s02, s03, s1, s2,
s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims,
freq_factors, sections, is_imrope, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_multi_sycl<forward>(
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
sections, is_imrope, stream);
} else {
GGML_ABORT("Fatal error: Tensor type unsupported!");
}
} else if (is_vision) {
GGML_SYCL_DEBUG("%s: vision path\n", __func__);
if (dst->src[0]->type == GGML_TYPE_F16) {
rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, sections, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F32) {
rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
main_stream);
if (src0->type == GGML_TYPE_F32) {
rope_vision_sycl<forward>(
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, sections,
stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_vision_sycl<forward>(
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
sections, stream);
} else {
GGML_ABORT("Fatal error: Tensor type unsupported!");
}
} else {
GGML_SYCL_DEBUG("%s: norm path\n", __func__);
if (dst->src[0]->type == GGML_TYPE_F32) {
rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
main_stream);
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_norm_sycl<forward, float, float>(
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_norm_sycl<forward, float, sycl::half>(
(const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_norm_sycl<forward, sycl::half, sycl::half>(
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
row_indices, set_rows_stride, stream);
} else {
GGML_ABORT("fatal error");
GGML_ABORT("Fatal error: Tensor type unsupported!");
}
}
}
void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
void ggml_sycl_rope(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
ggml_sycl_op_rope(ctx, dst);
ggml_sycl_op_rope_impl<true>(ctx, dst);
}
void ggml_sycl_rope_back(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
ggml_sycl_op_rope_impl<false>(ctx, dst);
}
void ggml_sycl_rope_fused(ggml_backend_sycl_context &ctx, ggml_tensor *rope,
ggml_tensor *set_rows) {
scope_op_debug_print scope_dbg_print(__func__, rope, /*num_src=*/3);
ggml_sycl_op_rope_impl<true>(ctx, rope, set_rows);
}
+6
View File
@@ -15,6 +15,12 @@
#include "common.hpp"
#define SYCL_ROPE_BLOCK_SIZE 256
void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
void ggml_sycl_rope_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_rope_fused(ggml_backend_sycl_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);
#endif // GGML_SYCL_ROPE_HPP
+203 -51
View File
@@ -27,6 +27,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
#include <iostream>
#include <tuple>
#include <vector>
#include <deque>
#include <sstream>
#include <utility>
#include <memory>
@@ -188,6 +189,11 @@ struct ggml_backend_vk_buffer_type_context {
struct vk_queue;
struct vk_command_buffer {
vk::CommandBuffer buf;
bool in_use = false;
};
// Stores command pool/buffers. There's an instance of this
// for each (context,queue) pair and for each (device,queue) pair.
struct vk_command_pool {
@@ -195,10 +201,16 @@ struct vk_command_pool {
void destroy(vk::Device& device);
vk::CommandPool pool;
uint32_t cmd_buffer_idx;
std::vector<vk::CommandBuffer> cmd_buffers;
// Using deque so the pointers to command buffers
// remain valid even if we add more
std::deque<vk_command_buffer> cmd_buffers;
vk_queue *q;
size_t buffers_in_use() const {
return std::count_if(cmd_buffers.begin(), cmd_buffers.end(),
[](const auto& cb) { return cb.in_use; });
}
};
// Prevent simultaneous submissions to the same queue.
@@ -813,6 +825,8 @@ struct vk_device_struct {
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
// [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128
vk_pipeline pipeline_gated_delta_net[3][2];
vk_pipeline pipeline_ssm_scan_f32_d128;
vk_pipeline pipeline_ssm_scan_f32_d256;
vk_pipeline pipeline_ssm_conv_f32;
@@ -878,10 +892,12 @@ struct vk_device_struct {
};
void vk_command_pool::init(vk_device& device, vk_queue *q_) {
cmd_buffer_idx = 0;
cmd_buffers.clear();
q = q_;
vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index);
vk::CommandPoolCreateInfo command_pool_create_info(
vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT | VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT),
q->queue_family_index);
pool = device->device.createCommandPool(command_pool_create_info);
}
@@ -929,6 +945,7 @@ struct vk_subbuffer {
struct vk_event {
vk::Event event;
vk::Fence fence;
vk_command_buffer* cmd_buffer = nullptr;
};
struct vk_semaphore {
@@ -937,7 +954,7 @@ struct vk_semaphore {
};
struct vk_submission {
vk::CommandBuffer buffer;
vk_command_buffer* buffer = nullptr;
std::vector<vk_semaphore> wait_semaphores;
std::vector<vk_semaphore> signal_semaphores;
};
@@ -1439,6 +1456,18 @@ struct vk_op_rwkv_wkv7_push_constants {
uint32_t C;
uint32_t H;
};
struct vk_op_gated_delta_net_push_constants {
uint32_t H;
uint32_t n_tokens;
uint32_t n_seqs;
uint32_t s_off;
uint32_t sq1, sq2, sq3;
uint32_t sv1, sv2, sv3;
uint32_t sb1, sb2, sb3;
uint32_t neq1, rq3;
float scale;
};
struct vk_op_ssm_scan_push_constants {
uint32_t nb02, nb03, nb12, nb13;
uint32_t nb21, nb22, nb31;
@@ -2283,25 +2312,15 @@ static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx
}
}
static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) {
static vk_command_buffer* ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) {
VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()");
if (p.cmd_buffers.size() > p.cmd_buffer_idx) {
// Reuse command buffer
return p.cmd_buffers[p.cmd_buffer_idx++];
}
vk::CommandBufferAllocateInfo command_buffer_alloc_info(
p.pool,
vk::CommandBufferLevel::ePrimary,
1);
const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
auto buf = cmd_buffers.front();
p.cmd_buffers.push_back(buf);
p.cmd_buffer_idx++;
return buf;
p.cmd_buffers.push_back({ cmd_buffers.front(), true });
return &p.cmd_buffers[p.cmd_buffers.size()-1];
}
static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
@@ -2368,7 +2387,7 @@ static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
tl_wait_semaphores[idx].data(),
stage_flags[idx].data(),
1,
&submission.buffer,
&submission.buffer->buf,
(uint32_t) submission.signal_semaphores.size(),
tl_signal_semaphores[idx].data(),
};
@@ -2492,7 +2511,11 @@ static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p)
// Requires command buffers to be done
device->device.resetCommandPool(p.pool);
p.cmd_buffer_idx = 0;
// Don't clear the command buffers and mark them as not in use.
// This allows us to reuse them
for (auto& cmd_buffer : p.cmd_buffers) {
cmd_buffer.in_use = false;
}
}
static void ggml_vk_queue_command_pools_cleanup(vk_device& device) {
@@ -2501,10 +2524,10 @@ static void ggml_vk_queue_command_pools_cleanup(vk_device& device) {
// Arbitrary frequency to cleanup/reuse command buffers
static constexpr uint32_t cleanup_frequency = 10;
if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
if (device->compute_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) {
ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool);
}
if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
if (device->transfer_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) {
ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool);
}
}
@@ -2752,7 +2775,7 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
}
subctx->s->buffer.pipelineBarrier(
subctx->s->buffer->buf.pipelineBarrier(
subctx->p->q->stage_flags,
subctx->p->q->stage_flags,
{},
@@ -2768,7 +2791,7 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
VK_LOG_DEBUG("ggml_vk_set_event()");
ctx->s->buffer.setEvent(
ctx->s->buffer->buf.setEvent(
event,
ctx->p->q->stage_flags
);
@@ -2780,7 +2803,7 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
return;
}
ctx->s->buffer.waitEvents(
ctx->s->buffer->buf.waitEvents(
events,
ctx->p->q->stage_flags,
ctx->p->q->stage_flags,
@@ -4559,6 +4582,23 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
{
const uint32_t gdn_sizes[] = {32, 64, 128};
const char * gdn_names[][2] = {
{"gated_delta_net_f32_d32", "gated_delta_net_f32_d32_kda"},
{"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
{"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
};
for (uint32_t si = 0; si < 3; si++) {
for (uint32_t kda = 0; kda < 2; kda++) {
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
"main", 7, sizeof(vk_op_gated_delta_net_push_constants),
{1, 1, 1}, {gdn_sizes[si], kda}, 1);
}
}
}
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
@@ -4567,7 +4607,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
}
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -6348,13 +6388,24 @@ static vk_subbuffer ggml_vk_tensor_subbuffer(
return vk_subbuffer{buffer, offset, size};
}
// Get a command buffer from pool. Create a new one if no reusable buffer is available
static vk_command_buffer* ggml_vk_get_or_create_cmd_buffer(vk_device& device, vk_command_pool& pool) {
for (auto& cmd_buffer : pool.cmd_buffers) {
if (!cmd_buffer.in_use) {
cmd_buffer.in_use = true;
return &cmd_buffer;
}
}
return ggml_vk_create_cmd_buffer(device, pool);
}
static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {
vk_submission s;
s.buffer = ggml_vk_create_cmd_buffer(device, p);
s.buffer = ggml_vk_get_or_create_cmd_buffer(device, p);
if (one_time) {
s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
s.buffer->buf.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
} else {
s.buffer.begin({ vk::CommandBufferUsageFlags{} });
s.buffer->buf.begin({ vk::CommandBufferUsageFlags{} });
}
return s;
@@ -6407,18 +6458,18 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
subctx->s->buffer->buf.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
subctx->s->buffer->buf.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
subctx->s->buffer->buf.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
pipeline->layout,
0,
{ descriptor_set },
{});
subctx->s->buffer.dispatch(wg0, wg1, wg2);
subctx->s->buffer->buf.dispatch(wg0, wg1, wg2);
}
static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
s.buffer.end();
s.buffer->buf.end();
s.wait_semaphores = std::move(wait_semaphores);
s.signal_semaphores = std::move(signal_semaphores);
@@ -6430,7 +6481,7 @@ static void ggml_vk_ctx_end(vk_context& ctx) {
return;
}
ctx->s->buffer.end();
ctx->s->buffer->buf.end();
ctx->s = nullptr;
}
@@ -6584,7 +6635,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
}
ggml_vk_sync_buffers(ctx, subctx);
subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices);
return;
}
@@ -6599,7 +6650,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
VkBufferCopy buf_copy{ 0, offset, copy_size };
ggml_vk_sync_buffers(ctx, subctx);
vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
for (uint64_t i3 = 0; i3 < ne3; i3++) {
for (uint64_t i2 = 0; i2 < ne2; i2++) {
@@ -6648,7 +6699,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
}
ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices);
return true;
}
VK_LOG_DEBUG("STAGING");
@@ -6670,7 +6721,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
copy_size};
ggml_vk_sync_buffers(nullptr, subctx);
vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
if (width == spitch) {
deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);
@@ -6756,7 +6807,7 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
if (buf != nullptr) {
// Memory is pinned, use as staging buffer
ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);
subctx->s->buffer->buf.copyBuffer(src->buffer, buf->buffer, slices);
return true;
}
@@ -6774,7 +6825,7 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
vk_buffer& staging_buffer = src->device->sync_staging;
ggml_vk_sync_buffers(nullptr, subctx);
subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, slices);
deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
return true;
@@ -6821,7 +6872,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds
VkBufferCopy bc{ src_offset, dst_offset, size };
vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc);
vkCmdCopyBuffer(ctx->s->buffer->buf, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc);
}
static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
@@ -6859,7 +6910,7 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t
}
// Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers
ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
ctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c);
}
static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
@@ -6874,7 +6925,7 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
ggml_vk_ctx_begin(dst->device, subctx);
subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
subctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c);
ggml_vk_ctx_end(subctx);
ggml_vk_submit(subctx, dst->device->fence);
@@ -8820,7 +8871,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
mask != nullptr, use_mask_opt, logit_softcap != 0);
@@ -9478,6 +9529,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
case GGML_OP_GATED_DELTA_NET:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t S_v = dst->src[2]->ne[0];
const uint32_t kda = (dst->src[3]->ne[0] == (int64_t)S_v) ? 1 : 0;
uint32_t si;
switch (S_v) {
case 32: si = 0; break;
case 64: si = 1; break;
case 128: si = 2; break;
default: return nullptr;
}
return ctx->device->pipeline_gated_delta_net[si][kda];
}
return nullptr;
case GGML_OP_SSM_SCAN:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t d_state = src0->ne[0];
@@ -10308,6 +10373,59 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
);
}
static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const ggml_tensor * src_q = dst->src[0];
const ggml_tensor * src_v = dst->src[2];
const ggml_tensor * src_beta = dst->src[4];
GGML_ASSERT(dst->buffer != nullptr);
const uint32_t S_v = (uint32_t)src_v->ne[0];
const uint32_t H = (uint32_t)src_v->ne[1];
const uint32_t n_tokens = (uint32_t)src_v->ne[2];
const uint32_t n_seqs = (uint32_t)src_v->ne[3];
const uint32_t s_off = S_v * H * n_tokens * n_seqs;
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
GGML_ASSERT(pipeline != nullptr);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
vk_subbuffer src_buf[6] = {};
for (int i = 0; i < 6; i++) {
src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
}
const uint32_t sq1 = (uint32_t)(src_q->nb[1] / sizeof(float));
const uint32_t sq2 = (uint32_t)(src_q->nb[2] / sizeof(float));
const uint32_t sq3 = (uint32_t)(src_q->nb[3] / sizeof(float));
const uint32_t sv1 = (uint32_t)(src_v->nb[1] / sizeof(float));
const uint32_t sv2 = (uint32_t)(src_v->nb[2] / sizeof(float));
const uint32_t sv3 = (uint32_t)(src_v->nb[3] / sizeof(float));
const uint32_t sb1 = (uint32_t)(src_beta->nb[1] / sizeof(float));
const uint32_t sb2 = (uint32_t)(src_beta->nb[2] / sizeof(float));
const uint32_t sb3 = (uint32_t)(src_beta->nb[3] / sizeof(float));
const uint32_t neq1 = (uint32_t)src_q->ne[1];
const uint32_t rq3 = (uint32_t)(src_v->ne[3] / src_q->ne[3]);
const float scale = 1.0f / sqrtf((float)S_v);
const vk_op_gated_delta_net_push_constants pc = {
H, n_tokens, n_seqs, s_off,
sq1, sq2, sq3,
sv1, sv2, sv3,
sb1, sb2, sb3,
neq1, rq3,
scale
};
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
pc, { H, n_seqs, 1u });
}
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
@@ -12682,7 +12800,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
if (vk_perf_logger_enabled && vk_perf_logger_concurrent) {
ctx->query_node_idx[ctx->query_idx] = node_idx;
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
}
}
// Add all fused nodes to the unsynchronized lists.
@@ -13024,6 +13142,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_GATED_DELTA_NET:
ggml_vk_gated_delta_net(ctx, compute_ctx, node);
break;
case GGML_OP_SSM_SCAN:
ggml_vk_ssm_scan(ctx, compute_ctx, node);
@@ -13521,7 +13644,7 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
buffer_cpy.dstOffset = dst_offset;
buffer_cpy.size = size;
cpy_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
deferred_memcpy(ctx->sync_staging->ptr, data, size, &cpy_ctx->in_memcpys);
ggml_vk_synchronize(ctx);
}
@@ -13555,7 +13678,7 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
buffer_cpy.dstOffset = 0;
buffer_cpy.size = size;
compute_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys);
ggml_vk_synchronize(ctx);
}
@@ -13633,8 +13756,12 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
}
vk_context compute_ctx;
vk_command_buffer* cmd_buf = nullptr;
if (do_transfer) {
compute_ctx = ctx->compute_ctx.lock();
if (compute_ctx->s) {
cmd_buf = compute_ctx->s->buffer;
}
ggml_vk_ctx_end(compute_ctx);
@@ -13668,6 +13795,9 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
}
ggml_vk_wait_for_fence(ctx);
ctx->submit_pending = false;
if (cmd_buf) {
cmd_buf->in_use = false;
}
}
if (do_transfer) {
@@ -14157,7 +14287,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
GGML_ASSERT(ctx->compute_ctx.expired());
compute_ctx = ggml_vk_get_compute_ctx(ctx);
ctx->query_idx = 0;
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
}
ctx->prealloc_y_last_pipeline_used = nullptr;
@@ -14393,7 +14523,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
// track a single node/fusion for the current query
ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
ctx->query_fusion_names[ctx->query_idx] = fusion_string;
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
} else {
// track a fusion string and number of fused ops for the current node_idx
ctx->query_fusion_names[i] = fusion_string;
@@ -14726,6 +14856,7 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
ggml_vk_submit_transfer_ctx(ctx);
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
auto* cmd_buf = compute_ctx->s->buffer; // retrieve pointer before it gets reset
// the backend interface doesn't have an explicit reset, so reset it here
// before we record the command to set it
@@ -14738,6 +14869,7 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
ggml_vk_submit(compute_ctx, {vkev->fence});
ctx->submit_pending = true;
vkev->cmd_buffer = cmd_buf;
ctx->compute_ctx.reset();
}
@@ -15426,6 +15558,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true; // all inputs are contiguous, see ggml.c
case GGML_OP_GATED_DELTA_NET:
{
const uint32_t S_v = op->src[2]->ne[0];
if (S_v != 32 && S_v != 64 && S_v != 128) {
return false;
}
for (int i = 0; i < 6; i++) {
if (op->src[i] == nullptr || op->src[i]->type != GGML_TYPE_F32) {
return false;
}
}
return op->type == GGML_TYPE_F32;
}
case GGML_OP_SSM_SCAN:
{
for (int i = 0; i < 6; i++) {
@@ -15557,6 +15702,10 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm
vk_event *vkev = (vk_event *)event->context;
VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
// Finished using current command buffer so we flag for reuse
if (vkev->cmd_buffer) {
vkev->cmd_buffer->in_use = false;
}
}
static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) {
@@ -16028,7 +16177,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_arange(ggml_ctx, start, stop, step);
} else if (tensor->op == GGML_OP_FILL) {
const float value = ggml_get_op_params_f32(tensor, 0);
tensor_clone = ggml_fill(ggml_ctx, tensor_clone, value);
tensor_clone = ggml_fill(ggml_ctx, src_clone[0], value);
} else if (tensor->op == GGML_OP_SQR) {
tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_SQRT) {
@@ -16299,6 +16448,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
} else if (tensor->op == GGML_OP_RWKV_WKV7) {
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
src_clone[4], src_clone[5], src_clone[6]);
} else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
src_clone[0]->flags = tensor->src[0]->flags;
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
@@ -33,6 +33,61 @@ layout (push_constant) uniform parameter {
shared float minsh[NUM_SUBGROUPS];
shared float maxsh[NUM_SUBGROUPS];
float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
void loadvec4(inout uint result, const uint i0, const uint i1, const uint i2, const uint i3, const bool need_bounds_check) {
const uint tid = gl_LocalInvocationIndex;
[[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
float min_v = FLT_MAX_OVER_2;
float max_v = -FLT_MAX_OVER_2;
[[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
uint j0 = (i + tid) % (Bc / 4);
uint j1 = (i + tid) / (Bc / 4);
j0 *= 4;
j0 += (i0 * 16 + block_x) * Bc;
j1 += i1 * Br;
if (!need_bounds_check || j0 + 3 < nem0) {
vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
[[unroll]] for (int c = 0; c < 4; ++c) {
min_v = min(min_v, f[c]);
max_v = max(max_v, f[c]);
}
} else {
[[unroll]] for (int c = 0; c < 4; ++c) {
if (j0 + c < nem0) {
float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);
min_v = min(min_v, f);
max_v = max(max_v, f);
}
}
}
}
min_v = subgroupMin(min_v);
max_v = subgroupMax(max_v);
if (gl_SubgroupInvocationID == 0) {
minsh[gl_SubgroupID] = min_v;
maxsh[gl_SubgroupID] = max_v;
}
barrier();
if (tid == 0) {
[[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
min_v = min(min_v, minsh[i]);
max_v = max(max_v, maxsh[i]);
}
if (max_v <= -FLT_MAX_OVER_2) {
result |= 1 << (2*block_x);
}
if (min_v == 0.0f && max_v == 0.0f) {
result |= 2 << (2*block_x);
}
}
barrier();
}
}
// For each Br x Bc block of the mask (input) buffer, read all values and check
// if it's all -inf or all zero. Write out a two-bit code indicating which it is
// (or zero for neither). Each workgroup processes 16 tiles and writes out a
@@ -48,50 +103,15 @@ void main() {
const uint i2 = gl_WorkGroupID.z % nem2;
const uint i3 = gl_WorkGroupID.z / nem2;
float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
uint result = 0;
// Fast path for fully in-bounds blocks where we can do f16vec4 loads
if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
[[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
float min_v = FLT_MAX_OVER_2;
float max_v = -FLT_MAX_OVER_2;
[[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
uint j0 = (i + tid) % (Bc / 4);
uint j1 = (i + tid) / (Bc / 4);
j0 *= 4;
j0 += (i0 * 16 + block_x) * Bc;
j1 += i1 * Br;
vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
[[unroll]] for (int c = 0; c < 4; ++c) {
min_v = min(min_v, f[c]);
max_v = max(max_v, f[c]);
}
}
min_v = subgroupMin(min_v);
max_v = subgroupMax(max_v);
if (gl_SubgroupInvocationID == 0) {
minsh[gl_SubgroupID] = min_v;
maxsh[gl_SubgroupID] = max_v;
}
barrier();
if (tid == 0) {
[[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
min_v = min(min_v, minsh[i]);
max_v = max(max_v, maxsh[i]);
}
if (max_v <= -FLT_MAX_OVER_2) {
result |= 1 << (2*block_x);
}
if (min_v == 0.0f && max_v == 0.0f) {
result |= 2 << (2*block_x);
}
}
barrier();
if ((i0 + 1) * 16 * Bc <= nem0) {
loadvec4(result, i0, i1, i2, i3, false);
} else {
loadvec4(result, i0, i1, i2, i3, true);
}
} else {
[[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
@@ -0,0 +1,128 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
layout(constant_id = 0) const uint S_V = 128;
layout(constant_id = 1) const uint KDA = 0;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout(push_constant) uniform Parameters {
uint H;
uint n_tokens;
uint n_seqs;
uint s_off;
uint sq1, sq2, sq3;
uint sv1, sv2, sv3;
uint sb1, sb2, sb3;
uint neq1, rq3;
float scale;
};
layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; };
layout(binding = 1) readonly buffer KBuf { FLOAT_TYPE data_k[]; };
layout(binding = 2) readonly buffer VBuf { FLOAT_TYPE data_v[]; };
layout(binding = 3) readonly buffer GBuf { FLOAT_TYPE data_g[]; };
layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; };
layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; };
layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; };
shared FLOAT_TYPE s_k[S_V];
shared FLOAT_TYPE s_q[S_V];
shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i])
void main() {
const uint head_id = gl_WorkGroupID.x;
const uint seq_id = gl_WorkGroupID.y;
const uint col = gl_LocalInvocationID.x;
const uint iq1 = head_id % neq1;
const uint iq3 = seq_id / rq3;
const uint state_size = S_V * S_V;
const uint state_base = (seq_id * H + head_id) * state_size;
FLOAT_TYPE state[S_V];
[[unroll]] for (uint i = 0; i < S_V; i++) {
state[i] = FLOAT_TYPE(data_state[state_base + i * S_V + col]);
}
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
for (uint t = 0; t < n_tokens; t++) {
const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
const uint k_off = q_off;
const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
s_q[col] = FLOAT_TYPE(data_q[q_off + col]);
s_k[col] = FLOAT_TYPE(data_k[k_off + col]);
const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
if (KDA != 0) {
const uint g_base = gb_off * S_V;
s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col]));
}
barrier();
const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);
if (KDA == 0) {
const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
FLOAT_TYPE kv_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
kv_col += dot(
vec4(state[i], state[i+1], state[i+2], state[i+3]),
vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
);
}
FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val;
FLOAT_TYPE attn_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
sv = g_val * sv + kv * delta_col;
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
}
data_dst[attn_off + col] = attn_col * scale;
} else {
FLOAT_TYPE kv_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
kv_col += dot(gv * sv, kv);
}
FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
FLOAT_TYPE attn_col = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
sv = gv * sv + kv * delta_col;
state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
}
data_dst[attn_off + col] = attn_col * scale;
}
attn_off += S_V * H;
barrier();
}
[[unroll]] for (uint i = 0; i < S_V; i++) {
data_dst[s_off + state_base + i * S_V + col] = state[i];
}
}
@@ -36,7 +36,7 @@ void main() {
barrier();
}
const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
const FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1));
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
@@ -5,8 +5,9 @@
#include "types.glsl"
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(constant_id = 1) const uint TOKENS_PER_WG = 16;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;
layout(binding = 0) readonly buffer Src0 { float src0[]; };
layout(binding = 1) readonly buffer Src1 { float src1[]; };
@@ -20,25 +21,30 @@ layout(push_constant) uniform PushConstants {
};
void main() {
const uint global_thread_id = gl_GlobalInvocationID.x;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_GlobalInvocationID.x;
const uint i2 = gl_WorkGroupID.y * TOKENS_PER_WG + gl_LocalInvocationID.y;
const uint i3 = gl_WorkGroupID.z;
if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) {
if (i1 >= nr || i2 >= n_t || i3 >= n_s) {
return;
}
const uint i1 = global_thread_id;
const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4);
const uint src1_base = i1 * (nb11 / 4);
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
float sum = 0.0;
[[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
const uint src0_idx = src0_base + i0;
const uint src1_idx = src1_base + i0;
sum += src0[src0_idx] * src1[src1_idx];
if (nc == 4) {
sum = dot(
vec4(src0[src0_base], src0[src0_base + 1], src0[src0_base + 2], src0[src0_base + 3]),
vec4(src1[src1_base], src1[src1_base + 1], src1[src1_base + 2], src1[src1_base + 3])
);
} else {
[[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
sum += src0[src0_base + i0] * src1[src1_base + i0];
}
}
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
dst[dst_idx] = sum;
}
@@ -987,6 +987,8 @@ void process_shaders() {
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}}));
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+124 -43
View File
@@ -42,11 +42,20 @@
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
// Matrix-vector multiplication parameters
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
// Must be multiple of 4 to work with vectorized paths, and must divide
// mul_mat_vec wg size
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_TILE_K 256
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
// Requires at least two (and multiple of 2) k-quant blocks per tile
#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
// default size for legacy matrix multiplication
#define WEBGPU_MUL_MAT_WG_SIZE 256
@@ -189,6 +198,22 @@ struct ggml_webgpu_concat_pipeline_key_hash {
}
};
/** Repeat **/
struct ggml_webgpu_repeat_pipeline_key {
int type;
bool operator==(const ggml_webgpu_repeat_pipeline_key & other) const { return type == other.type; }
};
struct ggml_webgpu_repeat_pipeline_key_hash {
size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
return seed;
}
};
/** Binary **/
struct ggml_webgpu_binary_pipeline_key {
@@ -199,7 +224,8 @@ struct ggml_webgpu_binary_pipeline_key {
bool src_overlap;
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
src_overlap == other.src_overlap;
}
};
@@ -421,6 +447,8 @@ class ggml_webgpu_shader_lib {
binary_pipelines; // type/op/inplace/overlap
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
concat_pipelines; // type
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
repeat_pipelines; // type
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
@@ -749,6 +777,36 @@ class ggml_webgpu_shader_lib {
std::vector<std::string> defines;
std::string variant = "mul_mat_vec";
// src0 type (matrix row)
switch (context.src0->type) {
case GGML_TYPE_F32:
defines.push_back("SRC0_INNER_TYPE=f32");
defines.push_back("MUL_ACC_FLOAT");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("MUL_ACC_FLOAT");
variant += "_f16";
break;
default:
{
// Quantized types: use helpers but accumulate in f16
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
std::string src0_name = src0_traits->type_name;
std::string type_upper = src0_name;
variant += "_" + src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
defines.push_back("BYTE_HELPERS");
defines.push_back("MUL_ACC_" + type_upper);
// For fast path we always dequantize from f16 inside the shader
defines.push_back("SRC0_INNER_TYPE=f16");
break;
}
}
// src1 type (vector)
switch (context.src1->type) {
case GGML_TYPE_F32:
@@ -763,39 +821,21 @@ class ggml_webgpu_shader_lib {
GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
}
// src0 type (matrix row)
switch (context.src0->type) {
case GGML_TYPE_F32:
defines.push_back("SRC0_INNER_TYPE=f32");
defines.push_back("MUL_ACC_FLOAT");
break;
case GGML_TYPE_F16:
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("MUL_ACC_FLOAT");
break;
default:
{
// Quantized types: use helpers but accumulate in f16
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
std::string src0_name = src0_traits->type_name;
std::string type_upper = src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
defines.push_back("BYTE_HELPERS");
defines.push_back("MUL_ACC_" + type_upper);
// For fast path we always dequantize from f16 inside the shader
defines.push_back("SRC0_INNER_TYPE=f16");
break;
}
}
// VEC/SCALAR controls
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K;
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
if (key.src0_type >= GGML_TYPE_Q2_K) {
tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
@@ -1061,10 +1101,10 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_binary_pipeline_key key = {
.type = context.dst->type,
.op = context.dst->op,
.inplace = context.inplace,
.overlap = context.overlap,
.type = context.dst->type,
.op = context.dst->op,
.inplace = context.inplace,
.overlap = context.overlap,
.src_overlap = context.src_overlap,
};
@@ -1125,7 +1165,7 @@ class ggml_webgpu_shader_lib {
}
std::vector<std::string> defines;
std::string variant = "concat";
std::string variant = "concat";
switch (key.type) {
case GGML_TYPE_F32:
@@ -1142,15 +1182,56 @@ class ggml_webgpu_shader_lib {
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_concat, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
auto processed = preprocessor.preprocess(wgsl_concat, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
concat_pipelines[key] = pipeline;
pipeline.context = decisions;
concat_pipelines[key] = pipeline;
return concat_pipelines[key];
}
webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_repeat_pipeline_key key = {
.type = context.dst->type,
};
auto it = repeat_pipelines.find(key);
if (it != repeat_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "repeat";
switch (key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_I32:
defines.push_back("TYPE_I32");
variant += "_i32";
break;
case GGML_TYPE_I16:
defines.push_back("TYPE_I16");
variant += "_i16";
break;
default:
GGML_ABORT("Unsupported type for repeat shader");
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_repeat, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
repeat_pipelines[key] = pipeline;
return repeat_pipelines[key];
}
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
const bool has_mask = context.src3 != nullptr;
const bool has_sinks = context.src4 != nullptr;
+285 -229
View File
@@ -8,7 +8,6 @@
#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "ggml-webgpu-shader-lib.hpp"
#include "pre_wgsl.hpp"
#ifdef __EMSCRIPTEN__
# include <emscripten/emscripten.h>
@@ -20,12 +19,18 @@
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <iostream>
#ifdef GGML_WEBGPU_GPU_PROFILE
# include <iomanip>
#endif
#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE)
# include <iostream>
#endif
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
@@ -70,22 +75,21 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
#endif // GGML_WEBGPU_CPU_PROFILE
#ifdef GGML_WEBGPU_GPU_PROFILE
# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24
# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32
# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps
#endif
/* Constants */
#define WEBGPU_NUM_PARAM_BUFS 48u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u
#define WEBGPU_NUM_PARAM_BUFS 96u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
// parameter buffer pool
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
// For operations which process a row in parallel, this seems like a reasonable
// default
@@ -118,14 +122,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
wgpu::BufferUsage usage,
const char * label);
struct webgpu_pool_bufs {
wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
};
// Holds a pool of parameter buffers for WebGPU operations
struct webgpu_buf_pool {
std::vector<webgpu_pool_bufs> free;
std::vector<wgpu::Buffer> free;
// The pool must be synchronized because
// 1. The memset pool is shared globally by every ggml buffer,
@@ -138,7 +137,6 @@ struct webgpu_buf_pool {
size_t cur_pool_size;
size_t max_pool_size;
wgpu::Device device;
wgpu::BufferUsage host_buf_usage;
wgpu::BufferUsage dev_buf_usage;
size_t buf_size;
bool should_grow;
@@ -147,53 +145,47 @@ struct webgpu_buf_pool {
int num_bufs,
size_t buf_size,
wgpu::BufferUsage dev_buf_usage,
wgpu::BufferUsage host_buf_usage,
bool should_grow = false,
size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
this->max_pool_size = max_pool_size;
this->cur_pool_size = num_bufs;
this->device = device;
this->host_buf_usage = host_buf_usage;
this->dev_buf_usage = dev_buf_usage;
this->buf_size = buf_size;
this->should_grow = should_grow;
this->max_pool_size = max_pool_size;
this->cur_pool_size = num_bufs;
this->device = device;
this->dev_buf_usage = dev_buf_usage;
this->buf_size = buf_size;
this->should_grow = should_grow;
for (int i = 0; i < num_bufs; i++) {
wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
free.push_back({ host_buf, dev_buf });
free.push_back(dev_buf);
}
}
webgpu_pool_bufs alloc_bufs() {
wgpu::Buffer alloc_bufs() {
std::unique_lock<std::mutex> lock(mutex);
if (!free.empty()) {
webgpu_pool_bufs bufs = free.back();
wgpu::Buffer buf = free.back();
free.pop_back();
return bufs;
return buf;
}
// Try growing the pool if no free buffers
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
cur_pool_size++;
wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
if (!(host_buf && dev_buf)) {
if (!dev_buf) {
GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
}
return webgpu_pool_bufs{ host_buf, dev_buf };
return dev_buf;
}
cv.wait(lock, [this] { return !free.empty(); });
webgpu_pool_bufs bufs = free.back();
wgpu::Buffer buf = free.back();
free.pop_back();
return bufs;
return buf;
}
void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
void free_bufs(std::vector<wgpu::Buffer> bufs) {
std::lock_guard<std::mutex> lock(mutex);
free.insert(free.end(), bufs.begin(), bufs.end());
cv.notify_all();
@@ -201,12 +193,9 @@ struct webgpu_buf_pool {
void cleanup() {
std::lock_guard<std::mutex> lock(mutex);
for (auto & bufs : free) {
if (bufs.host_buf) {
bufs.host_buf.Destroy();
}
if (bufs.dev_buf) {
bufs.dev_buf.Destroy();
for (auto & buf : free) {
if (buf) {
buf.Destroy();
}
}
free.clear();
@@ -280,10 +269,9 @@ struct webgpu_gpu_profile_buf_pool {
#endif
struct webgpu_command {
uint32_t num_kernels;
wgpu::CommandBuffer commands;
std::vector<webgpu_pool_bufs> params_bufs;
std::optional<webgpu_pool_bufs> set_rows_error_bufs;
uint32_t num_kernels;
wgpu::CommandBuffer commands;
std::vector<wgpu::Buffer> params_bufs;
#ifdef GGML_WEBGPU_GPU_PROFILE
webgpu_gpu_profile_bufs timestamp_query_bufs;
std::string pipeline_name;
@@ -358,6 +346,13 @@ struct webgpu_global_context_struct {
typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
struct webgpu_submission {
wgpu::FutureWaitInfo submit_done;
#ifdef GGML_WEBGPU_GPU_PROFILE
std::vector<wgpu::FutureWaitInfo> profile_futures;
#endif
};
// All the base objects needed to run operations on a WebGPU device
struct webgpu_context_struct {
// Points to global instances owned by ggml_backend_webgpu_reg_context
@@ -366,7 +361,8 @@ struct webgpu_context_struct {
std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
webgpu_buf_pool param_buf_pool;
webgpu_buf_pool set_rows_error_buf_pool;
wgpu::Buffer set_rows_dev_error_buf;
wgpu::Buffer set_rows_host_error_buf;
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
@@ -458,67 +454,105 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
/** End WebGPU object initializations */
/** WebGPU Actions */
static void erase_completed(std::vector<wgpu::FutureWaitInfo> & futures) {
static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) {
switch (status) {
case wgpu::WaitStatus::Success:
return true;
case wgpu::WaitStatus::TimedOut:
if (allow_timeout) {
return false;
}
GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n");
return false;
case wgpu::WaitStatus::Error:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
return false;
default:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
return false;
}
}
#ifdef GGML_WEBGPU_GPU_PROFILE
static void ggml_backend_webgpu_erase_completed_futures(std::vector<wgpu::FutureWaitInfo> & futures) {
futures.erase(std::remove_if(futures.begin(), futures.end(),
[](const wgpu::FutureWaitInfo & info) { return info.completed; }),
futures.end());
}
// Wait for the queue to finish processing all submitted work
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
std::vector<wgpu::FutureWaitInfo> & futures,
bool block = true) {
// If we have too many in-flight submissions, wait on the oldest one first.
static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx,
std::vector<wgpu::FutureWaitInfo> & futures,
bool block) {
if (futures.empty()) {
return;
}
uint64_t timeout_ms = block ? UINT64_MAX : 0;
while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX);
if (waitStatus == wgpu::WaitStatus::Error) {
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
if (block) {
while (!futures.empty()) {
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
ggml_backend_webgpu_erase_completed_futures(futures);
}
}
if (futures[0].completed) {
futures.erase(futures.begin());
} else {
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
ggml_backend_webgpu_erase_completed_futures(futures);
}
}
}
#endif
// Wait for the queue to finish processing all submitted work
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
std::vector<webgpu_submission> & subs,
bool block = true) {
// If we have too many in-flight submissions, wait on the oldest one first.
if (subs.empty()) {
return;
}
while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX);
if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
#ifdef GGML_WEBGPU_GPU_PROFILE
ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
#endif
subs.erase(subs.begin());
}
}
if (futures.empty()) {
if (subs.empty()) {
return;
}
if (block) {
while (!futures.empty()) {
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
switch (waitStatus) {
case wgpu::WaitStatus::Success:
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
erase_completed(futures);
break;
case wgpu::WaitStatus::Error:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
break;
default:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
break;
for (auto & sub : subs) {
while (!sub.submit_done.completed) {
auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX);
ggml_backend_webgpu_handle_wait_status(waitStatus);
}
#ifdef GGML_WEBGPU_GPU_PROFILE
ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true);
#endif
}
subs.clear();
} else {
// Poll once and return
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
switch (waitStatus) {
case wgpu::WaitStatus::Success:
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
erase_completed(futures);
break;
case wgpu::WaitStatus::TimedOut:
break;
case wgpu::WaitStatus::Error:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
break;
default:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
break;
// Poll each submit future once and remove completed submissions.
for (auto sub = subs.begin(); sub != subs.end();) {
auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);
ggml_backend_webgpu_handle_wait_status(waitStatus, true);
#ifdef GGML_WEBGPU_GPU_PROFILE
ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);
if (sub->submit_done.completed && sub->profile_futures.empty()) {
#else
if (sub->submit_done.completed) {
#endif
sub = subs.erase(sub);
} else {
++sub;
}
}
}
}
@@ -554,14 +588,12 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
}
#endif
static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
webgpu_global_context ctx,
std::vector<webgpu_command> commands,
webgpu_buf_pool & param_buf_pool,
webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx,
std::vector<webgpu_command> & commands,
webgpu_buf_pool & param_buf_pool) {
std::vector<wgpu::CommandBuffer> command_buffers;
std::vector<webgpu_pool_bufs> params_bufs;
std::vector<webgpu_pool_bufs> set_rows_error_bufs;
std::vector<wgpu::Buffer> params_bufs;
webgpu_submission submission;
#ifdef GGML_WEBGPU_GPU_PROFILE
std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
#endif
@@ -569,14 +601,9 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
for (const auto & command : commands) {
command_buffers.push_back(command.commands);
params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
if (command.set_rows_error_bufs) {
set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
}
}
ctx->queue.Submit(command_buffers.size(), command_buffers.data());
std::vector<wgpu::FutureWaitInfo> futures;
wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[&param_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
@@ -586,27 +613,7 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
// Free the staged buffers
param_buf_pool.free_bufs(params_bufs);
});
futures.push_back({ p_f });
for (const auto & bufs : set_rows_error_bufs) {
wgpu::Future f = bufs.host_buf.MapAsync(
wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
[set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
} else {
const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
if (*error_data) {
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
}
// We can't unmap in here due to WebGPU reentrancy limitations.
if (set_rows_error_buf_pool) {
set_rows_error_buf_pool->free_bufs({ bufs });
}
}
});
futures.push_back({ f });
}
submission.submit_done = { p_f };
#ifdef GGML_WEBGPU_GPU_PROFILE
for (const auto & command : commands) {
@@ -623,14 +630,14 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
// WebGPU timestamps are in ns; convert to ms
double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
ctx->shader_gpu_time_ms[label] += elapsed_ms;
// We can't unmap in here due to WebGPU reentrancy limitations.
ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
}
// We can't unmap in here due to WebGPU reentrancy limitations.
ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
});
futures.push_back({ f });
submission.profile_futures.push_back({ f });
}
#endif
return futures;
return submission;
}
static webgpu_command ggml_backend_webgpu_build_multi(
@@ -639,32 +646,21 @@ static webgpu_command ggml_backend_webgpu_build_multi(
const std::vector<webgpu_pipeline> & pipelines,
const std::vector<std::vector<uint32_t>> & params_list,
const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list,
const std::optional<webgpu_pool_bufs> & set_rows_error_bufs = std::nullopt) {
const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list) {
GGML_ASSERT(pipelines.size() == params_list.size());
GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
GGML_ASSERT(pipelines.size() == workgroups_list.size());
std::vector<webgpu_pool_bufs> params_bufs_list;
std::vector<wgpu::BindGroup> bind_groups;
std::vector<wgpu::Buffer> params_bufs_list;
std::vector<wgpu::BindGroup> bind_groups;
for (size_t i = 0; i < pipelines.size(); i++) {
webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs();
ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0,
params_bufs.host_buf.GetSize());
uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
for (size_t j = 0; j < params_list[i].size(); j++) {
_params[j] = params_list[i][j];
}
params_bufs.host_buf.Unmap();
wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs();
std::vector<wgpu::BindGroupEntry> entries = bind_group_entries_list[i];
uint32_t params_binding_num = entries.size();
entries.push_back({ .binding = params_binding_num,
.buffer = params_bufs.dev_buf,
.offset = 0,
.size = params_bufs.dev_buf.GetSize() });
entries.push_back(
{ .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() });
wgpu::BindGroupDescriptor bind_group_desc;
bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0);
@@ -677,15 +673,8 @@ static webgpu_command ggml_backend_webgpu_build_multi(
}
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
for (const auto & params_bufs : params_bufs_list) {
encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
}
// If there are SET_ROWS operations in this submission, copy their error
// buffers to the host.
if (set_rows_error_bufs) {
encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
set_rows_error_bufs->host_buf.GetSize());
for (size_t i = 0; i < params_bufs_list.size(); i++) {
ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
}
#ifdef GGML_WEBGPU_GPU_PROFILE
@@ -718,7 +707,6 @@ static webgpu_command ggml_backend_webgpu_build_multi(
webgpu_command result = {};
result.commands = commands;
result.params_bufs = params_bufs_list;
result.set_rows_error_bufs = set_rows_error_bufs;
result.num_kernels = pipelines.size();
#ifdef GGML_WEBGPU_GPU_PROFILE
result.timestamp_query_bufs = ts_bufs;
@@ -734,13 +722,13 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context &
std::vector<uint32_t> params,
std::vector<wgpu::BindGroupEntry> bind_group_entries,
uint32_t wg_x,
uint32_t wg_y = 1,
std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
uint32_t wg_y = 1) {
return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
{
pipeline
},
{ params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs);
{ std::move(params) }, { std::move(bind_group_entries) },
{ { wg_x, wg_y } });
}
static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
@@ -757,8 +745,9 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
webgpu_command command =
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool);
ggml_backend_webgpu_wait(ctx, futures);
std::vector<webgpu_command> commands = { command };
std::vector<webgpu_submission> sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
ggml_backend_webgpu_wait(ctx, sub);
}
/** End WebGPU Actions */
@@ -805,7 +794,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
std::cout << "\nggml_webgpu: gpu breakdown:\n";
for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2)
<< pct << "%)\n";
}
#endif
@@ -978,14 +968,6 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());
std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
if (decisions->i64_idx) {
error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
error_bufs->host_buf.Unmap();
}
}
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
@@ -1018,8 +1000,10 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
};
if (decisions->i64_idx) {
entries.push_back(
{ .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
entries.push_back({ .binding = 3,
.buffer = ctx->set_rows_dev_error_buf,
.offset = 0,
.size = ctx->set_rows_dev_error_buf.GetSize() });
}
uint32_t threads;
@@ -1029,8 +1013,7 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
}
uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1,
error_bufs);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1);
}
// Workgroup size is a common constant
@@ -1108,12 +1091,26 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
use_fast = (src0->type == GGML_TYPE_F16);
break;
case GGML_TYPE_F32:
// TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
switch (src0->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q6_K:
use_fast = true;
break;
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
// we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
use_fast = !is_vec;
break;
default:
break;
}
@@ -1187,17 +1184,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
if (use_fast && is_vec) {
auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
uint32_t batches = dst->ne[2] * dst->ne[3];
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
uint32_t total_wg = output_groups * batches;
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
} else if (use_fast) {
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
// Fast-path tiled/subgroup calculations
uint32_t wg_m, wg_n;
uint32_t wg_m;
uint32_t wg_n;
if (decisions->use_subgroup_matrix) {
uint32_t wg_m_sg_tile =
decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
@@ -1215,7 +1213,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
} else { // legacy
auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_size = decisions->wg_size;
uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
@@ -1514,10 +1512,10 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
}
static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
uint32_t dim = (uint32_t) dst->op_params[0];
std::vector<uint32_t> params = {
@@ -1538,28 +1536,22 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
dim,
(uint32_t)src0->ne[dim]
(uint32_t) src0->ne[dim]
};
std::vector<wgpu::BindGroupEntry> entries = {
{
.binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0)
},
{
.binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1)
},
{
.binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst)
}
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
ggml_webgpu_shader_lib_context shader_lib_ctx = {
@@ -1569,9 +1561,51 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
std::vector<uint32_t> params = { ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) /
ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src0->ne[0]),
(uint32_t) (src0->ne[1]),
(uint32_t) (src0->ne[2]),
(uint32_t) (src0->ne[3]),
(uint32_t) (dst->ne[0]),
(uint32_t) (dst->ne[1]),
(uint32_t) (dst->ne[2]) };
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@@ -1623,7 +1657,12 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -2161,6 +2200,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
return ggml_webgpu_binary_op(ctx, src0, src1, node);
case GGML_OP_CONCAT:
return ggml_webgpu_concat(ctx, src0, src1, node);
case GGML_OP_REPEAT:
return ggml_webgpu_repeat(ctx, src0, node);
case GGML_OP_RMS_NORM:
return ggml_webgpu_rms_norm(ctx, src0, node);
case GGML_OP_ROPE:
@@ -2172,19 +2213,12 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_SOFT_MAX:
return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
case GGML_OP_UNARY:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_CLAMP:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_FILL:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_LOG:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SQR:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SQRT:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SIN:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_COS:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_PAD:
@@ -2192,7 +2226,6 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_ARGMAX:
return ggml_webgpu_argmax(ctx, src0, node);
case GGML_OP_ARGSORT:
return ggml_webgpu_argsort(ctx, src0, node);
case GGML_OP_TOP_K:
// we reuse the same argsort implementation for top_k
return ggml_webgpu_argsort(ctx, src0, node);
@@ -2214,33 +2247,51 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
std::vector<webgpu_command> commands;
std::vector<wgpu::FutureWaitInfo> futures;
uint32_t num_batched_kernels = 0;
std::vector<webgpu_command> commands;
std::vector<webgpu_submission> subs;
uint32_t num_batched_kernels = 0;
bool contains_set_rows = false;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
contains_set_rows = true;
}
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
commands.push_back(*cmd);
num_batched_kernels += cmd.value().num_kernels;
}
if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
num_batched_kernels = 0;
std::vector<wgpu::FutureWaitInfo> compute_futures = ggml_backend_webgpu_submit(
ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
futures.insert(futures.end(), compute_futures.begin(), compute_futures.end());
num_batched_kernels = 0;
subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
// Process events and check for completed submissions
ctx->global_ctx->instance.ProcessEvents();
ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
ggml_backend_webgpu_wait(ctx->global_ctx, subs, false);
commands.clear();
}
}
if (!commands.empty()) {
auto new_futures =
ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
futures.insert(futures.end(), new_futures.begin(), new_futures.end());
subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
commands.clear();
}
ggml_backend_webgpu_wait(ctx->global_ctx, futures);
// If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking.
if (contains_set_rows) {
wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
ctx->set_rows_host_error_buf.GetSize());
wgpu::CommandBuffer set_rows_commands = encoder.Finish();
ctx->global_ctx->queue.Submit(1, &set_rows_commands);
ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0,
ctx->set_rows_host_error_buf.GetSize());
const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange();
if (*error_data) {
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
}
ctx->set_rows_host_error_buf.Unmap();
}
ggml_backend_webgpu_wait(ctx->global_ctx, subs);
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
return GGML_STATUS_SUCCESS;
}
@@ -2859,10 +2910,12 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);
webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf");
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
@@ -2910,10 +2963,10 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
/* .iface = */ {
/* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
/* .alloc_buffer = */
ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
},
/* .device = */
dev,
@@ -2991,6 +3044,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_CONCAT:
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
break;
case GGML_OP_REPEAT:
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32 || src0->type == GGML_TYPE_I16);
break;
case GGML_OP_CPY:
case GGML_OP_CONT:
supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -11,7 +11,7 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
shmem[idx + 2] = val.z;
shmem[idx + 3] = val.w;
}
#endif
#endif // VEC
#ifdef SCALAR
#define VEC_SIZE 1
@@ -23,7 +23,7 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
fn store_shmem(val: f16, idx: u32) {
shmem[idx] = val;
}
#endif
#endif // SCALAR
#ifdef INIT_SRC0_SHMEM_FLOAT
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
@@ -40,7 +40,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
store_shmem(SHMEM_TYPE(src0_val), elem_idx);
}
}
#endif
#endif // INIT_SRC0_SHMEM_FLOAT
#ifdef INIT_SRC1_SHMEM_FLOAT
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
@@ -57,7 +57,7 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
}
}
#endif
#endif // INIT_SRC1_SHMEM_FLOAT
#ifdef INIT_SRC0_SHMEM_Q4_0
const BLOCK_SIZE = 32u;
@@ -100,4 +100,667 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
}
}
#endif
#endif // INIT_SRC0_SHMEM_Q4_0
#ifdef INIT_SRC0_SHMEM_Q4_1
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_lo = f16(q_byte & 0xF) * d + m;
let q_hi = f16((q_byte >> 4) & 0xF) * d + m;
shmem[shmem_idx + j * 2 + k] = q_lo;
shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q4_1
#ifdef INIT_SRC0_SHMEM_Q5_0
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// tile_k is defined as 32u, so blocks_k ends up being 1 always
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let qh0 = src0[scale_idx + 1u];
let qh1 = src0[scale_idx + 2u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let j_adjusted = j + (block_offset / 2u);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d;
shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q5_0
#ifdef INIT_SRC0_SHMEM_Q5_1
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// tile_k is defined as 32u, so blocks_k ends up being 1 always
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
let qh0 = src0[scale_idx + 2u];
let qh1 = src0[scale_idx + 3u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let j_adjusted = j + (block_offset / 2u);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m;
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m;
shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q5_1
#ifdef INIT_SRC0_SHMEM_Q8_0
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_0 = src0[scale_idx + 1u + block_offset + j];
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
shmem[shmem_idx + j * 2 + k] = q_val;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q8_0
#ifdef INIT_SRC0_SHMEM_Q8_1
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d + m;
shmem[shmem_idx + j * 2 + k] = q_val;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q8_1
#ifdef INIT_SRC0_SHMEM_Q2_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 42u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
// Use standard thread layout instead of lane/row_group
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx + 40u];
let dmin = src0[scale_idx + 41u];
// Decode the element at position k_in_block
let block_of_32 = k_in_block / 32u;
let pos_in_32 = k_in_block % 32u;
let q_b_idx = (block_of_32 / 4u) * 32u;
let shift = (block_of_32 % 4u) * 2u;
let k = (pos_in_32 / 16u) * 16u;
let l = pos_in_32 % 16u;
let is = k_in_block / 16u;
let sc_0 = src0[scale_idx + 2u * (is / 4u)];
let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
let sc = get_byte(sc_packed, is % 4u);
let dl = d * f16(sc & 0xFu);
let ml = dmin * f16(sc >> 4u);
let q_idx = q_b_idx + k + l;
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 3u;
let q_val = f16(qs_val) * dl - ml;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q2_K
#ifdef INIT_SRC0_SHMEM_Q3_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 55u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx + 54u];
// Load and unpack scales
let kmask1: u32 = 0x03030303u;
let kmask2: u32 = 0x0f0f0f0fu;
var scale_vals: array<u32, 4>;
for (var i: u32 = 0u; i < 4u; i++) {
let scale_0 = src0[scale_idx + 48u + (2u*i)];
let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
}
var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);
scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);
// Load hmask and qs arrays
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0u; i < 8u; i++) {
let hmask_0 = src0[scale_idx + (2u*i)];
let hmask_1 = src0[scale_idx + (2u*i) + 1u];
hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
}
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16u; i++) {
let qs_0 = src0[scale_idx + 16u + (2u*i)];
let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
}
let half = k_in_block / 128u; // 0 or 1
let pos_in_half = k_in_block % 128u; // 0-127
let shift_group = pos_in_half / 32u; // 0-3
let pos_in_32 = pos_in_half % 32u; // 0-31
let k_group = pos_in_32 / 16u; // 0 or 1
let l = pos_in_32 % 16u; // 0-15
let q_b_idx = half * 32u; // 0 or 32
let shift = shift_group * 2u; // 0, 2, 4, 6
let k = k_group * 16u; // 0 or 16
let is = k_in_block / 16u; // 0-15
// m increments every 32 elements across entire 256 element block
let m_shift = k_in_block / 32u; // 0-7
let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128
let sc = get_byte(scale_vals[is / 4u], is % 4u);
let dl = d * (f16(sc) - 32.0);
let q_idx = q_b_idx + k + l;
let hm_idx = k + l;
let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);
let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
let qs_val = (q_byte >> shift) & 3u;
let q_val = (f16(qs_val) - f16(hm)) * dl;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q3_K
#ifdef INIT_SRC0_SHMEM_Q4_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 72u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let dmin = src0[scale_idx + 1u];
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
let scale_0 = src0[scale_idx + 2u + (2u*i)];
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
}
// Map k_in_block to loop structure:
// Outer loop over 64-element groups (alternating q_b_idx)
// Inner loop over 2 shifts per group
let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx)
let pos_in_64 = k_in_block % 64u; // 0-63
let shift_group = pos_in_64 / 32u; // 0 or 1
let l = pos_in_64 % 32u; // 0-31
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
let shift = shift_group * 4u; // 0 or 4
let is = k_in_block / 32u; // 0-7
var sc: u32;
var mn: u32;
if (is < 4u) {
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
}
let dl = d * f16(sc);
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 0xFu;
let q_val = f16(qs_val) * dl - ml;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q4_K
#ifdef INIT_SRC0_SHMEM_Q5_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 88u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let dmin = src0[scale_idx + 1u];
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
let scale_0 = src0[scale_idx + 2u + (2u*i)];
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
}
// The original loop processes elements in groups of 64
// Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
// But u increments EVERY 32 elements (after each l loop)
let group_of_64 = k_in_block / 64u; // 0-3
let pos_in_64 = k_in_block % 64u; // 0-63
let shift_group = pos_in_64 / 32u; // 0 or 1
let l = pos_in_64 % 32u; // 0-31
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
let shift = shift_group * 4u; // 0 or 4
let is = k_in_block / 32u; // 0-7
// u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)
let u_shift = k_in_block / 32u; // 0-7
let u: u32 = 1u << u_shift;
var sc: u32;
var mn: u32;
if (is < 4u) {
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
}
let dl = d * f16(sc);
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
let qh_byte = get_byte(qh_packed, l % 4u);
let qs_val = (q_byte >> shift) & 0xFu;
let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q5_K
#ifdef INIT_SRC0_SHMEM_Q6_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 105u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let half = k_in_block / 128u;
let pos_in_half = k_in_block % 128u;
let quarter = pos_in_half / 32u;
let l = pos_in_half % 32u;
let ql_b_idx = half * 64u;
let qh_b_idx = half * 32u;
let sc_b_idx = half * 8u;
// Load only ql13 word needed
let ql13_flat = ql_b_idx + l;
let ql13_word = ql13_flat / 4u;
let ql13 = bitcast<u32>(vec2(
src0[scale_idx + 2u * ql13_word],
src0[scale_idx + 2u * ql13_word + 1u]
));
let ql13_b = get_byte(ql13, ql13_flat % 4u);
// Load only ql24 word needed
let ql24_flat = ql_b_idx + l + 32u;
let ql24_word = ql24_flat / 4u;
let ql24 = bitcast<u32>(vec2(
src0[scale_idx + 2u * ql24_word],
src0[scale_idx + 2u * ql24_word + 1u]
));
let ql24_b = get_byte(ql24, ql24_flat % 4u);
// Load only qh word needed
let qh_flat = qh_b_idx + l;
let qh_word = qh_flat / 4u;
let qh = bitcast<u32>(vec2(
src0[scale_idx + 64u + 2u * qh_word],
src0[scale_idx + 64u + 2u * qh_word + 1u]
));
let qh_b = get_byte(qh, qh_flat % 4u);
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);
let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);
// Load only the scale word needed
let is = l / 16u;
let sc_idx = sc_b_idx + is + quarter * 2u;
let sc_word = sc_idx / 4u;
let sc = bitcast<u32>(vec2(
src0[scale_idx + 96u + 2u * sc_word],
src0[scale_idx + 96u + 2u * sc_word + 1u]
));
let sc_val = get_byte_i32(sc, sc_idx % 4u);
let d = src0[scale_idx + 104u];
var q_val: f16;
if (quarter == 0u) {
q_val = q1;
} else if (quarter == 1u) {
q_val = q2;
} else if (quarter == 2u) {
q_val = q3;
} else {
q_val = q4;
}
shmem[elem_idx] = d * f16(sc_val) * q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q6_K
@@ -50,6 +50,7 @@ fn get_local_m(thread_id: u32) -> u32 {
const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
@@ -1,4 +1,3 @@
enable f16;
#include "common_decls.tmpl"
@@ -84,6 +83,294 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
}
#endif
#ifdef MUL_ACC_Q4_1
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 10u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let m = f32(src0[scale_idx + 1u]);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
let q_lo = f32(q_byte & 0xF) * d + m;
local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
}
}
}
return local_sum;
}
#endif
#ifdef MUL_ACC_Q5_0
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 11u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let qh0 = src0[scale_idx + 1u];
let qh1 = src0[scale_idx + 2u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let j_adjusted = j + (block_offset / 2u);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
}
}
}
return local_sum;
}
#endif
#ifdef MUL_ACC_Q5_1
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 12u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let m = src0[scale_idx + 1u];
let qh0 = src0[scale_idx + 2u];
let qh1 = src0[scale_idx + 3u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let j_adjusted = j + (block_offset / 2u);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m);
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m);
local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
}
}
}
return local_sum;
}
#endif
#ifdef MUL_ACC_Q8_0
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 17u;
const WEIGHTS_PER_F16 = 2u;
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 1 + block_offset + j];
let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
}
}
}
return local_sum;
}
#endif
#ifdef MUL_ACC_Q8_1
const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 18u;
const WEIGHTS_PER_F16 = 2u;
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let m = src0[scale_idx + 1u];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d + f32(m);
local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
}
}
}
return local_sum;
}
#endif
#ifdef MUL_ACC_Q6_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 105u;
fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
let aligned = byte_offset & ~3u;
let idx = bbase + aligned / 2u;
return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));
}
fn byte_of(v: u32, b: u32) -> u32 {
return (v >> (b * 8u)) & 0xFFu;
}
fn sbyte_of(v: u32, b: u32) -> i32 {
let raw = i32((v >> (b * 8u)) & 0xFFu);
return select(raw, raw - 256, raw >= 128);
}
fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let tid = tig / 2u;
let ix = tig % 2u;
let ip = tid / 8u;
let il = tid % 8u;
let l0 = 4u * il;
let is = 8u * ip + l0 / 16u;
let y_offset = 128u * ip + l0;
let q_offset_l = 64u * ip + l0;
let q_offset_h = 32u * ip + l0;
let nb = tile_size / BLOCK_SIZE;
let k_block_start = k_outer / BLOCK_SIZE;
// Aligned scale byte position (is can be odd)
let sc_base_byte = 192u + (is & ~3u);
let sc_byte_pos = is & 3u;
var local_sum = 0.0;
for (var i = ix; i < nb; i += 2u) {
let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
let d_raw = load_u32_at(bbase, 208u);
let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
let ql1_u32 = load_u32_at(bbase, q_offset_l);
let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u);
let qh_u32 = load_u32_at(bbase, 128u + q_offset_h);
let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
var sums = vec4<f32>(0.0, 0.0, 0.0, 0.0);
for (var l = 0u; l < 4u; l++) {
let y_base = i * BLOCK_SIZE + y_offset + l;
let yl0 = f32(shared_vector[y_base]);
let yl1 = f32(shared_vector[y_base + 32u]);
let yl2 = f32(shared_vector[y_base + 64u]);
let yl3 = f32(shared_vector[y_base + 96u]);
let q1b = byte_of(ql1_u32, l);
let q2b = byte_of(ql2_u32, l);
let qhb = byte_of(qh_u32, l);
let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
let dq2 = f32(i32((q1b >> 4u) | ((qhb & 0x30u) )) - 32);
let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32);
sums[0] += yl0 * dq0;
sums[1] += yl1 * dq1;
sums[2] += yl2 * dq2;
sums[3] += yl3 * dq3;
}
local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
sums[2] * f32(sc4) + sums[3] * f32(sc6));
}
return local_sum;
}
#endif
struct MulMatParams {
offset_src0: u32,
offset_src1: u32,
@@ -191,4 +478,3 @@ fn main(
dst[dst_idx / VEC_SIZE] = store_val(group_base);
}
}
@@ -0,0 +1,67 @@
enable f16;
struct Params {
ne: u32,
offset_src0: u32,
offset_dst: u32,
stride_src0_0: u32,
stride_src0_1: u32,
stride_src0_2: u32,
stride_src0_3: u32,
a_ne0: u32,
a_ne1: u32,
a_ne2: u32,
a_ne3: u32,
ne0: u32,
ne1: u32,
ne2: u32,
};
#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_I32
#define DataType i32
#endif
#ifdef TYPE_I16
// same size (16-bit) is sufficient for repeat
#define DataType f16
#endif
@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;
@group(0) @binding(1)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(2)
var<uniform> params: Params;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
let a_i0 = i0 % params.a_ne0;
let a_i1 = i1 % params.a_ne1;
let a_i2 = i2 % params.a_ne2;
let a_i3 = i3 % params.a_ne3;
let a_index = a_i0 * params.stride_src0_0 +
a_i1 * params.stride_src0_1 +
a_i2 * params.stride_src0_2 +
a_i3 * params.stride_src0_3;
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + a_index];
}
}
+10
View File
@@ -718,6 +718,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
.to_float = (ggml_to_float_t) dequantize_row_mxfp4,
.from_float_ref = (ggml_from_float_t)quantize_row_mxfp4_ref,
},
[GGML_TYPE_NVFP4] = {
.type_name = "nvfp4",
.blck_size = QK_NVFP4,
.type_size = sizeof(block_nvfp4),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_nvfp4,
.from_float_ref = (ggml_from_float_t)quantize_row_nvfp4_ref,
},
[GGML_TYPE_Q2_K] = {
.type_name = "q2_K",
.blck_size = QK_K,
@@ -1374,6 +1382,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break;
case GGML_FTYPE_MOSTLY_NVFP4: wtype = GGML_TYPE_NVFP4; break;
case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
@@ -7641,6 +7650,7 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+11
View File
@@ -125,6 +125,7 @@ class Keys:
EXPERT_GROUP_SCALE = "{arch}.expert_group_scale"
EXPERTS_PER_GROUP = "{arch}.experts_per_group"
MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers"
MOE_LATENT_SIZE = "{arch}.moe_latent_size"
NEXTN_PREDICT_LAYERS = "{arch}.nextn_predict_layers"
NUM_DEEPSTACK_LAYERS = "{arch}.n_deepstack_layers"
POOLING_TYPE = "{arch}.pooling_type"
@@ -543,6 +544,8 @@ class MODEL_TENSOR(IntEnum):
FFN_DOWN_CHEXP = auto()
FFN_UP_CHEXP = auto()
FFN_EXP_PROBS_B = auto()
MOE_LATENT_DOWN = auto() # nemotron 3 super
MOE_LATENT_UP = auto() # nemotron 3 super
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()
LAYER_OUT_NORM = auto()
@@ -986,6 +989,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.FFN_GATE_UP_EXP: "blk.{bid}.ffn_gate_up_exps",
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
MODEL_TENSOR.MOE_LATENT_DOWN: "blk.{bid}.ffn_latent_down", # nemotron 3 super
MODEL_TENSOR.MOE_LATENT_UP: "blk.{bid}.ffn_latent_up", # nemotron 3 super
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n
MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n
@@ -2913,6 +2918,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
# expert latent
MODEL_TENSOR.MOE_LATENT_DOWN,
MODEL_TENSOR.MOE_LATENT_UP,
# shared expert
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
@@ -3776,6 +3784,7 @@ class GGMLQuantizationType(IntEnum):
TQ1_0 = 34
TQ2_0 = 35
MXFP4 = 39
NVFP4 = 40
class ExpertGatingFuncType(IntEnum):
@@ -3872,6 +3881,7 @@ class VisionProjectorType:
GEMMA3 = "gemma3"
GEMMA3NV = "gemma3nv"
GEMMA3NA = "gemma3na"
PHI4 = "phi4"
IDEFICS3 = "idefics3"
PIXTRAL = "pixtral"
LLAMA4 = "llama4"
@@ -3933,6 +3943,7 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13),
GGMLQuantizationType.TQ2_0: (256, 2 + 64),
GGMLQuantizationType.MXFP4: (32, 1 + 16),
GGMLQuantizationType.NVFP4: (64, 4 + 32),
}
+10 -4
View File
@@ -139,10 +139,13 @@ class GGUFWriter:
size = prod(shape)
if "_exps." in name:
expert_count = shape[-2 if ".bias" in name else -3]
expert_params += (size // expert_count)
expert_sum += expert_count
n_expert_tensors += 1
if len(shape) >= 3:
expert_count = shape[-2 if ".bias" in name else -3]
expert_params += (size // expert_count)
expert_sum += expert_count
n_expert_tensors += 1
else:
shared_params += size
else:
shared_params += size
@@ -859,6 +862,9 @@ class GGUFWriter:
def add_moe_every_n_layers(self, value: int) -> None:
self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
def add_moe_latent_size(self, value: int) -> None:
self.add_uint32(Keys.LLM.MOE_LATENT_SIZE.format(arch=self.arch), value)
def add_nextn_predict_layers(self, count: int) -> None:
self.add_uint32(Keys.LLM.NEXTN_PREDICT_LAYERS.format(arch=self.arch), count)
+59
View File
@@ -704,6 +704,65 @@ class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
return (d * qs.astype(np.float32))
class NVFP4(__Quant, qtype=GGMLQuantizationType.NVFP4):
# E2M1 values doubled (kvalues_mxfp4 convention)
kvalues = (0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12)
@staticmethod
def ue4m3_to_fp32(x: np.ndarray) -> np.ndarray:
"""Decode unsigned E4M3 (bias=7) to float, with 0.5 factor for kvalues convention."""
exp = (x >> 3).astype(np.int32) & 0xF
man = (x & 0x7).astype(np.float32)
raw = np.where(
exp == 0,
man * 2**-9,
(1.0 + man / 8.0) * (2.0 ** (exp.astype(np.float32) - 7)))
return np.where((x == 0) | (x == 0x7F), 0.0, raw * 0.5)
@staticmethod
def fp32_to_ue4m3(x: np.ndarray) -> np.ndarray:
"""Vectorized float32 to unsigned E4M3, matching ggml_fp32_to_ue4m3 in C."""
x = np.clip(x, 0.0, 448.0).astype(np.float32)
bits = x.view(np.uint32)
fp32_exp = ((bits >> 23) & 0xFF).astype(np.int32) - 127
fp32_man = ((bits >> 20) & 0x7).astype(np.int32)
ue4m3_exp = fp32_exp + 7
# Subnormal
sub_man = np.clip((x * 512.0 + 0.5).astype(np.int32), 0, 7)
sub_result = np.where(sub_man >= 1, sub_man, 0).astype(np.uint8)
# Normal with rounding
round_bit = ((bits >> 19) & 1).astype(np.int32)
man = fp32_man + round_bit
exp = ue4m3_exp.copy()
overflow = man > 7
man = np.where(overflow, 0, man)
exp = np.where(overflow, exp + 1, exp)
normal_result = np.where(exp >= 15, np.uint8(0x7E), ((exp << 3) | man).astype(np.uint8))
return np.where(x <= 0.0, np.uint8(0),
np.where(ue4m3_exp <= 0, sub_result,
np.where(ue4m3_exp >= 15, np.uint8(0x7E), normal_result)))
@classmethod
def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
n_super = blocks.shape[0]
d_bytes, qs = np.hsplit(blocks, [4])
d = cls.ue4m3_to_fp32(d_bytes).reshape(n_super, 4, 1) # (n_super, 4, 1)
qs = qs.reshape(n_super, 4, 8)
lo = (qs & np.uint8(0x0F)).view(np.int8)
hi = (qs >> np.uint8(4)).view(np.int8)
vals = np.concatenate([lo, hi], axis=-1) # (n_super, 4, 16)
kvalues = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16)
vals = np.take_along_axis(kvalues, vals, axis=-1)
return (d * vals.astype(np.float32)).reshape(n_super, 64)
class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS):
ksigns: bytes = (
b"\x00\x81\x82\x03\x84\x05\x06\x87\x88\x09\x0a\x8b\x0c\x8d\x8e\x0f"
@@ -65,6 +65,7 @@ byteswap_tensors = {
gguf.GGMLQuantizationType.Q4_K: byteswap_q4_k,
gguf.GGMLQuantizationType.Q6_K: byteswap_q6_k,
gguf.GGMLQuantizationType.MXFP4: byteswap_noop,
gguf.GGMLQuantizationType.NVFP4: byteswap_noop,
}
+8
View File
@@ -571,6 +571,14 @@ class TensorNameMap:
"model.layers.{bid}.mlp.experts.gate_up_proj",
),
MODEL_TENSOR.MOE_LATENT_DOWN: (
"backbone.layers.{bid}.mixer.fc1_latent_proj", # nemotron 3 super
),
MODEL_TENSOR.MOE_LATENT_UP: (
"backbone.layers.{bid}.mixer.fc2_latent_proj", # nemotron 3 super
),
# Feed-forward down
MODEL_TENSOR.FFN_DOWN: (
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
+1
View File
@@ -68,6 +68,7 @@ class GGMLQuants:
"q2_K", "q3_K", "q4_K", "q5_K", "q6_K",
"tq1_0", "tq2_0",
"mxfp4",
"nvfp4",
"iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m",
"iq4_nl", "iq4_xs",
):
+1
View File
@@ -153,6 +153,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors
LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors
LLAMA_FTYPE_MOSTLY_NVFP4 = 39, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};
+355
View File
@@ -0,0 +1,355 @@
{#--------TOOL RENDERING FUNCTIONS---------#}
{#---------------------------------------------------------------
Converts JSON Schema (dict) to a TypeScript type definition
----------------------------------------------------------------#}
{%- macro json_schema_to_typescript(schema, indent="") -%}
{%- set ADDITIONAL_JSON_KEYS = ['format', 'maxItems', 'maximum', 'minItems', 'minimum', 'pattern'] -%}
{%- set ty = schema.get("type") -%}
{# ---------------- OBJECT ---------------- #}
{%- if ty == "object" -%}
{{- "{\n" -}}
{# Start building property list #}
{%- set props = schema.get("properties", {}) -%}
{%- set required = schema.get("required", []) -%}
{%- set has_additional_props = schema.get("additionalProperties") is defined -%}
{%- set additional_props_type = none -%}
{%- if has_additional_props -%}
{%- if schema.additionalProperties == true -%}
{%- set additional_props_type = {'type': 'any'} -%}
{%- elif schema.additionalProperties is mapping -%}
{%- set additional_props_type = schema.additionalProperties -%}
{%- endif -%}
{%- endif -%}
{%- for key, val in props.items() -%}
{# ---------- Description Comments ---------- #}
{%- if "description" in val -%}
{%- for line in val['description'].split('\n') -%}
{%- if line.strip() -%}
{{- indent + '// ' + line + '\n' -}}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{# ---------- Additional JSON Keys ---------- #}
{%- for add_key, add_val in val.items() -%}
{%- if add_key in ADDITIONAL_JSON_KEYS -%}
{%- if add_val is string -%}
{{- indent + '// ' + add_key + ': "' + add_val + '"' + '\n' -}}
{%- else -%}
{{- indent + '// ' + add_key + ': ' ~ add_val ~ '\n' -}}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{# ---------- Property Definition ---------- #}
{%- set type_str = json_schema_to_typescript(
val,
indent + " "
) -%}
{{- indent + key + ('' if key in required else '?') + ': ' + type_str + ',' -}}
{%- if "default" in val or "defalut_value" in val -%}
{%- set default = val.get("default", val.get("defalut_value")) -%}
{%- if default is string -%}
{{- ' // default: "' + default + '"' -}}
{%- else -%}
{{- ' // default: ' ~ default -}}
{%- endif -%}
{%- endif -%}
{{- "\n" -}}
{%- endfor -%}
{# Handle additionalProperties as index signature #}
{%- if has_additional_props and additional_props_type is not none -%}
{%- set additional_type_str = json_schema_to_typescript(
additional_props_type,
indent + " "
) -%}
{{- indent + '[key: string]: ' + additional_type_str + '\n' -}}
{%- endif -%}
{{- indent[: (indent|length - " "|length) ] + '}' -}}
{# ---------------- STRING ---------------- #}
{%- elif ty == "string" -%}
{%- if schema.get("enum") -%}
{%- set ns = namespace(enum = []) -%}
{%- for en in schema['enum'] -%}
{%- set ns.enum = ns.enum + ['"' ~ en ~ '"'] -%}
{%- endfor -%}
{{- ns.enum | join(' | ') -}}
{%- elif schema.get("format", "none") in ['date-time', 'date'] -%}
{{- 'Date' -}}
{%- else -%}
{{- 'string' -}}
{%- endif -%}
{# ---------------- NUMBER / INTEGER ---------------- #}
{%- elif ty in ["number", "integer"] -%}
{%- if schema.get("enum") -%}
{{- schema.enum | join(' | ') -}}
{%- else -%}
{{- 'number' -}}
{%- endif -%}
{# ---------------- BOOLEAN ---------------- #}
{%- elif ty == "boolean" -%}
{{- 'boolean' -}}
{# ---------------- ARRAY ---------------- #}
{%- elif ty == "array" -%}
{%- if "items" in schema -%}
{{- json_schema_to_typescript(schema['items'], indent) + '[]' -}}
{%- else -%}
{{- 'Array<any>' -}}
{%- endif -%}
{# ---------------- FALLBACK ---------------- #}
{%- else -%}
{{- 'any' -}}
{%- endif -%}
{%- endmacro -%}
{#---------------------------------------------------------------
Renders a namespace and its tool definitions in TypeScript style
----------------------------------------------------------------#}
{%- macro render_tool_namespace(namespace_name, tools) -%}
{%- set ns = namespace(sections = ['namespace ' ~ namespace_name ~ ' {']) -%}
{%- for tool in tools -%}
{%- if tool.function -%}
{%- set tool = tool.function -%}
{%- endif -%}
{%- set ns_tool = namespace(content_lines=[]) -%}
{# ---------- TOOL DESCRIPTION ---------- #}
{%- if tool.get('description') -%}
{%- for line in tool['description'].split('\n') -%}
{%- if line.strip() -%}
{%- set ns_tool.content_lines = ns_tool.content_lines + ['// ' ~ line] -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{# ---------- TOOL SIGNATURE ---------- #}
{%- set main_body = "" -%}
{%- set params = tool.get("parameters") -%}
{%- if params and params.get("properties") -%}
{%- set param_type = json_schema_to_typescript(params, " ") -%}
{%- set main_body = 'type ' ~ tool.name ~ ' = (_: ' ~ param_type ~ ') => ' -%}
{%- else -%}
{%- set main_body = 'type ' ~ tool.name ~ ' = () => ' -%}
{%- endif -%}
{# ---------- RETURN TYPE ---------- #}
{%- set return_params = tool.get("return_parameters") -%}
{%- if return_params and return_params.get("properties") -%}
{%- set return_type = json_schema_to_typescript(return_params, " ") -%}
{%- set main_body = main_body ~ return_type -%}
{%- else -%}
{%- set main_body = main_body ~ 'any' -%}
{%- endif -%}
{%- set main_body = main_body ~ ';\n' -%}
{%- set ns_tool.content_lines = ns_tool.content_lines + [main_body] -%}
{# ---------- ADD TOOL TO SECTIONS ---------- #}
{%- set ns.sections = ns.sections + [ns_tool.content_lines | join('\n')] -%}
{%- endfor -%}
{%- set ns.sections = ns.sections + ['} // namespace ' ~ namespace_name] -%}
{{- ns.sections | join('\n') -}}
{%- endmacro -%}
{# ----------- MESSAGE RENDERING HELPER FUNCTIONS ------------ #}
{%- macro render_role_message(message, role=None) -%}
{%- if not role -%}
{%- set role = message["role"] -%}
{%- endif -%}
{%- set message_content = message['content'] or '' -%}
{%- if message_content is not string -%}
{%- set message_content = message_content | tojson(ensure_ascii=False) -%}
{%- endif -%}
{{- role + add_tokens.role_sep + message_content + add_tokens.message_sep -}}
{%- endmacro -%}
{%- macro render_function_call(message) -%}
{%- set call = message['content'] -%}
{%- if call.function -%}
{%- set call = call.function -%}
{%- endif -%}
{%- set arguments = call['arguments'] -%}
{%- if arguments is not string -%}
{%- set arguments = arguments| tojson(ensure_ascii=False) -%}
{%- endif -%}
{{- render_role_message(
{
'role': 'function call',
'content': '{"name": "' ~ call['name'] ~ '", "arguments": ' ~ arguments ~ '}'
}
) -}}
{%- endmacro -%}
{# ----- SPECIAL TOKENS ----- #}
{%- set add_tokens = namespace(
role_sep="<|role_sep|>\n",
message_sep="<|message_sep|>\n\n"
) -%}
{# ----- DEFAULT DEVSYSTEM ----- #}
{%- set DEVSYSTEM -%}
<role_description>
Description of the roles available in the dialog.
`developer system`
A message added by Sber before the main dialog. It has the highest priority and sets global, non-overridable conditions (for example, conversation rules, the safety policy, the assistant's overall response style, etc.).
`system`
A system instruction added by developers or by the user, but with a lower priority than `developer system`. It usually describes the assistant's instructions, a specific response style, and other conditions for this particular dialog.
`user`
A message or request from the user. The assistant follows it if it does not conflict with higher-priority instructions (see <instruction_priority>).
`user memory`
A sequence of the most up-to-date long-term facts about the user at the time of their request, presented as a JSON list of strings. Facts are listed in chronological order, meaning newer facts are appended to the end of the sequence. When facts are changed or deleted, records of previous facts remain in the sequence. The assistant saves facts using a function and uses them in accordance with the <memory_guidelines> block below.
`added files`
Metadata about files available for use in the dialog, presented in JSON format. It contains the following keys: id (a unique file identifier), name (file name), type (file type).
`assistant`
The assistant's reply to the user's request. If the system instruction or the user does not set additional rules for `assistant`, this reply must comply with the instructions in the <assistant_guidelines> block below. The list of functions available to call is contained in `function descriptions`. The name of the required function and its arguments will be generated next by the `function call` role. In its replies, the assistant follows the instructions in accordance with <instruction_priority>.
`function descriptions`
Function descriptions in TypeScript format. A function is a special tool (or a set of instructions) that the assistant can call to perform specific actions, computations, or obtain data needed to solve the user's task. Each function description contains blocks with the name, description, and arguments. Sometimes the description contains separate blocks with return parameters and usage examples that illustrate the correct call and arguments.
`function call`
The function that `assistant` calls based on the dialog context, and its arguments. The function is invoked in strict accordance with the instructions in the <function_usage> block.
`function result`
The result of the last function call.
</role_description>
<available_modalities>
The assistant can work with the following modalities: text, available functions.
</available_modalities>
<instruction_priority>
If instructions from different roles conflict within the dialog context, observe the following priorities:
`developer system` > `system` > `user` > `function descriptions` > `function result` > `user memory`
</instruction_priority>
<function_usage>
Basic instructions for working with functions.
Only call those functions that are described in `function descriptions`.
Call available functions when, according to their description, such a call will help provide a more complete and/or accurate answer to the user's request. Fill in function arguments using information from the dialog context. If a function could help answer the request but a required argument is missing from the context, ask the user for the missing data before calling the function. If a necessary function is unavailable or an error occurs, briefly inform the user and, if possible, suggest an alternative.
</function_usage>
<memory_guidelines>
Rules for using facts in long-term memory:
If there is no message under the `user memory` role in the dialog, this is equivalent to the absence of long-term facts about the user in memory. In that case, information about the user is limited to the current dialog, and no new facts should be saved.
</memory_guidelines>
<assistant_guidelines>
You are a helpful assistant.
# Instructions
- Strictly follow the instruction priority.
- Maintain a logical chain of reasoning when answering the user's question.
- For complex questions (for example, STEM), try to answer in detail unless the system message or dialog context limits the response length.
- Be helpful, truthful, and avoid unsafe or prohibited content in your responses.
- Try to reply in the language in which the user asked their question.
</assistant_guidelines>
A dialog will follow below.
The dialog may include various roles described in the <role_description> block.
Each turn begins with the role name and a special token that marks the end of the role's full name, and ends with a special end-of-turn token.
Your task is to continue the dialog from the last specified role in accordance with the dialog context.
{%- endset -%}
{#- ---------------------- RENDERING STARTS HERE ---------------------- -#}
{# ----- RENDER BOS TOKEN ----- #}
{{- bos_token -}}
{# ----- RENDER DEVSYSTEM ----- #}
{{- render_role_message({"role": "developer system", "content": DEVSYSTEM}) -}}
{# ----- RENDER SYSTEM IF PRESENT ----- #}
{%- if messages and messages[0]['role'] == 'system' -%}
{{- render_role_message(messages[0]) -}}
{%- set messages = messages[1:] -%}
{%- endif -%}
{# ----- RENDER TOOLS ----- #}
{%- if tools -%}
{%- set tools_content = (
render_tool_namespace('functions', tools)
+ "\n\n"
) -%}
{{- render_role_message({'role': 'function descriptions', 'content': tools_content}) -}}
{%- endif -%}
{# ----- MAIN MESSAGE LOOP ----- #}
{%- for message in messages -%}
{# ----- TOOL MESSAGE -------#}
{%- if message['role'] == 'tool' -%}
{{- render_role_message(message, 'function result') -}}
{# ----- ASSISTANT MESSAGE ----- #}
{%- elif message['role'] == 'assistant' -%}
{# ----- FUNCTION CALL PART CHECKING: SINGLE CALL SETUP ----- #}
{%- if message.tool_calls is defined and message.tool_calls -%}
{%- set function_call = message.tool_calls[0] -%}
{%- else -%}
{%- set function_call = None -%}
{%- endif -%}
{# ----- MAIN ASSISTANT RENDERING ----- #}
{{- render_role_message({'role': 'assistant', 'content': message.content}) -}}
{%- if function_call -%}
{{- render_function_call({'role': 'function call', 'content': function_call}) -}}
{%- endif -%}
{# ----- OTHER MESSAGES ----- #}
{%- else -%}
{{- render_role_message(message) -}}
{%- endif -%}
{# ----- ADDING GENERATION PROMPT ----- #}
{%- if loop.last and add_generation_prompt and message['role'] != 'assistant' -%}
{{- 'assistant' + add_tokens.role_sep -}}
{%- endif -%}
{%- endfor -%}
@@ -0,0 +1,339 @@
{#--------TOOL RENDERING FUNCTIONS---------#}
{#---------------------------------------------------------------
Converts JSON Schema (dict) to a TypeScript type definition
----------------------------------------------------------------#}
{%- macro json_schema_to_typescript(schema, indent="") -%}
{%- set ADDITIONAL_JSON_KEYS = ['format', 'maxItems', 'maximum', 'minItems', 'minimum', 'pattern'] -%}
{%- set ty = schema.get("type") -%}
{# ---------------- OBJECT ---------------- #}
{%- if ty == "object" -%}
{{- "{\n" -}}
{# Start building property list #}
{%- set props = schema.get("properties", {}) -%}
{%- set required = schema.get("required", []) -%}
{%- set has_additional_props = schema.get("additionalProperties") is defined -%}
{%- set additional_props_type = none -%}
{%- if has_additional_props -%}
{%- if schema.additionalProperties == true -%}
{%- set additional_props_type = {'type': 'any'} -%}
{%- elif schema.additionalProperties is mapping -%}
{%- set additional_props_type = schema.additionalProperties -%}
{%- endif -%}
{%- endif -%}
{%- for key, val in props.items() -%}
{# ---------- Description Comments ---------- #}
{%- if "description" in val -%}
{%- for line in val['description'].split('\n') -%}
{%- if line.strip() -%}
{{- indent + '// ' + line + '\n' -}}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{# ---------- Additional JSON Keys ---------- #}
{%- for add_key, add_val in val.items() -%}
{%- if add_key in ADDITIONAL_JSON_KEYS -%}
{%- if add_val is string -%}
{{- indent + '// ' + add_key + ': "' + add_val + '"' + '\n' -}}
{%- else -%}
{{- indent + '// ' + add_key + ': ' ~ add_val ~ '\n' -}}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{# ---------- Property Definition ---------- #}
{%- set type_str = json_schema_to_typescript(
val,
indent + " "
) -%}
{{- indent + key + ('' if key in required else '?') + ': ' + type_str + ',' -}}
{%- if "default" in val or "defalut_value" in val -%}
{%- set default = val.get("default", val.get("defalut_value")) -%}
{%- if default is string -%}
{{- ' // default: "' + default + '"' -}}
{%- else -%}
{{- ' // default: ' ~ default -}}
{%- endif -%}
{%- endif -%}
{{- "\n" -}}
{%- endfor -%}
{# Handle additionalProperties as index signature #}
{%- if has_additional_props and additional_props_type is not none -%}
{%- set additional_type_str = json_schema_to_typescript(
additional_props_type,
indent + " "
) -%}
{{- indent + '[key: string]: ' + additional_type_str + '\n' -}}
{%- endif -%}
{{- indent[: (indent|length - " "|length) ] + '}' -}}
{# ---------------- STRING ---------------- #}
{%- elif ty == "string" -%}
{%- if schema.get("enum") -%}
{%- set ns = namespace(enum = []) -%}
{%- for en in schema['enum'] -%}
{%- set ns.enum = ns.enum + ['"' ~ en ~ '"'] -%}
{%- endfor -%}
{{- ns.enum | join(' | ') -}}
{%- elif schema.get("format", "none") in ['date-time', 'date'] -%}
{{- 'Date' -}}
{%- else -%}
{{- 'string' -}}
{%- endif -%}
{# ---------------- NUMBER / INTEGER ---------------- #}
{%- elif ty in ["number", "integer"] -%}
{%- if schema.get("enum") -%}
{{- schema.enum | join(' | ') -}}
{%- else -%}
{{- 'number' -}}
{%- endif -%}
{# ---------------- BOOLEAN ---------------- #}
{%- elif ty == "boolean" -%}
{{- 'boolean' -}}
{# ---------------- ARRAY ---------------- #}
{%- elif ty == "array" -%}
{%- if "items" in schema -%}
{{- json_schema_to_typescript(schema['items'], indent) + '[]' -}}
{%- else -%}
{{- 'Array<any>' -}}
{%- endif -%}
{# ---------------- FALLBACK ---------------- #}
{%- else -%}
{{- 'any' -}}
{%- endif -%}
{%- endmacro -%}
{#---------------------------------------------------------------
Renders a namespace and its tool definitions in TypeScript style
----------------------------------------------------------------#}
{%- macro render_tool_namespace(namespace_name, tools) -%}
{%- set ns = namespace(sections = ['namespace ' ~ namespace_name ~ ' {']) -%}
{%- for tool in tools -%}
{%- if tool.function -%}
{%- set tool = tool.function -%}
{%- endif -%}
{%- set ns_tool = namespace(content_lines=[]) -%}
{# ---------- TOOL DESCRIPTION ---------- #}
{%- if tool.get('description') -%}
{%- for line in tool['description'].split('\n') -%}
{%- if line.strip() -%}
{%- set ns_tool.content_lines = ns_tool.content_lines + ['// ' ~ line] -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{# ---------- TOOL SIGNATURE ---------- #}
{%- set main_body = "" -%}
{%- set params = tool.get("parameters") -%}
{%- if params and params.get("properties") -%}
{%- set param_type = json_schema_to_typescript(params, " ") -%}
{%- set main_body = 'type ' ~ tool.name ~ ' = (_: ' ~ param_type ~ ') => ' -%}
{%- else -%}
{%- set main_body = 'type ' ~ tool.name ~ ' = () => ' -%}
{%- endif -%}
{# ---------- RETURN TYPE ---------- #}
{%- set return_params = tool.get("return_parameters") -%}
{%- if return_params and return_params.get("properties") -%}
{%- set return_type = json_schema_to_typescript(return_params, " ") -%}
{%- set main_body = main_body ~ return_type -%}
{%- else -%}
{%- set main_body = main_body ~ 'any' -%}
{%- endif -%}
{%- set main_body = main_body ~ ';\n' -%}
{%- set ns_tool.content_lines = ns_tool.content_lines + [main_body] -%}
{# ---------- ADD TOOL TO SECTIONS ---------- #}
{%- set ns.sections = ns.sections + [ns_tool.content_lines | join('\n')] -%}
{%- endfor -%}
{%- set ns.sections = ns.sections + ['} // namespace ' ~ namespace_name] -%}
{{- ns.sections | join('\n') -}}
{%- endmacro -%}
{# ----------- MESSAGE RENDERING HELPER FUNCTIONS ------------ #}
{%- macro render_function_call(call) -%}
{%- if call.function -%}
{%- set call = call.function -%}
{%- endif -%}
{%- set arguments = call['arguments'] -%}
{%- if arguments is not string -%}
{%- set arguments = arguments| tojson(ensure_ascii=False) -%}
{%- endif -%}
{{- '{"name": "' ~ call['name'] ~ '", "arguments": ' ~ arguments ~ '}' -}}
{%- endmacro -%}
{%- macro render_role_message(message, role=None) -%}
{%- if not role -%}
{%- set role = message["role"] -%}
{%- endif -%}
{%- set message_content = message['content'] or '' -%}
{%- if message_content is not string -%}
{%- set message_content = message_content | tojson(ensure_ascii=False) -%}
{%- endif -%}
{{- role + add_tokens.role_sep + message_content -}}
{%- if message.tool_calls is defined and message.tool_calls -%}
{{- add_tokens.function_call + render_function_call(message.tool_calls[0]) -}}
{%- endif -%}
{{- add_tokens.message_sep -}}
{%- endmacro -%}
{# ----- SPECIAL TOKENS ----- #}
{%- set add_tokens = namespace(
role_sep="<|role_sep|>\n",
message_sep="<|message_sep|>\n\n",
function_call="<|function_call|>"
) -%}
{# ----- DEFAULT DEVSYSTEM ----- #}
{%- set DEVSYSTEM -%}
<role_description>
Description of the roles available in the dialog.
`developer system`
A message added by Sber before the main dialog. It has the highest priority and sets global, non-overridable conditions (for example, conversation rules, the safety policy, the assistant's overall response style, etc.).
`system`
A system instruction added by developers or by the user, but with a lower priority than `developer system`. It usually describes the assistant's instructions, a specific response style, and other conditions for this particular dialog.
`user`
A message or request from the user. The assistant follows it if it does not conflict with higher-priority instructions (see <instruction_priority>).
`user memory`
A sequence of the most up-to-date long-term facts about the user at the time of their request, presented as a JSON list of strings. Facts are listed in chronological order, meaning newer facts are appended to the end of the sequence. When facts are changed or deleted, records of previous facts remain in the sequence. The assistant saves facts using a function and uses them in accordance with the <memory_guidelines> block below.
`added files`
Metadata about files available for use in the dialog, presented in JSON format. It contains the following keys: id (a unique file identifier), name (file name), type (file type).
`assistant`
The assistant's reply to the user's request. If the system instruction or the user does not set additional rules for `assistant`, this reply must comply with the instructions in the <assistant_guidelines> block below. The list of functions available to call is contained in `function descriptions`. The name of the required function and its arguments will be generated next by the `function call` role. In its replies, the assistant follows the instructions in accordance with <instruction_priority>.
`function descriptions`
Function descriptions in TypeScript format. A function is a special tool (or a set of instructions) that the assistant can call to perform specific actions, computations, or obtain data needed to solve the user's task. Each function description contains blocks with the name, description, and arguments. Sometimes the description contains separate blocks with return parameters and usage examples that illustrate the correct call and arguments.
`function call`
The function that `assistant` calls based on the dialog context, and its arguments. The function is invoked in strict accordance with the instructions in the <function_usage> block.
`function result`
The result of the last function call.
</role_description>
<available_modalities>
The assistant can work with the following modalities: text, available functions.
</available_modalities>
<instruction_priority>
If instructions from different roles conflict within the dialog context, observe the following priorities:
`developer system` > `system` > `user` > `function descriptions` > `function result` > `user memory`
</instruction_priority>
<function_usage>
Basic instructions for working with functions.
Only call those functions that are described in `function descriptions`.
Call available functions when, according to their description, such a call will help provide a more complete and/or accurate answer to the user's request. Fill in function arguments using information from the dialog context. If a function could help answer the request but a required argument is missing from the context, ask the user for the missing data before calling the function. If a necessary function is unavailable or an error occurs, briefly inform the user and, if possible, suggest an alternative.
</function_usage>
<memory_guidelines>
Rules for using facts in long-term memory:
If there is no message under the `user memory` role in the dialog, this is equivalent to the absence of long-term facts about the user in memory. In that case, information about the user is limited to the current dialog, and no new facts should be saved.
</memory_guidelines>
<assistant_guidelines>
You are a helpful assistant.
# Instructions
- Strictly follow the instruction priority.
- Maintain a logical chain of reasoning when answering the user's question.
- For complex questions (for example, STEM), try to answer in detail unless the system message or dialog context limits the response length.
- Be helpful, truthful, and avoid unsafe or prohibited content in your responses.
- Try to reply in the language in which the user asked their question.
</assistant_guidelines>
A dialog will follow below.
The dialog may include various roles described in the <role_description> block.
Each turn begins with the role name and a special token that marks the end of the role's full name, and ends with a special end-of-turn token.
Your task is to continue the dialog from the last specified role in accordance with the dialog context.
{%- endset -%}
{#- ---------------------- RENDERING STARTS HERE ---------------------- -#}
{# ----- RENDER BOS TOKEN ----- #}
{{- bos_token -}}
{# ----- RENDER DEVSYSTEM ----- #}
{{- render_role_message({"role": "developer system", "content": DEVSYSTEM}) -}}
{# ----- RENDER SYSTEM IF PRESENT ----- #}
{%- if messages and messages[0]['role'] == 'system' -%}
{{- render_role_message(messages[0]) -}}
{%- set messages = messages[1:] -%}
{%- endif -%}
{# ----- RENDER TOOLS ----- #}
{%- if tools -%}
{%- set tools_content = (
render_tool_namespace('functions', tools)
+ "\n\n"
) -%}
{{- render_role_message({'role': 'function descriptions', 'content': tools_content}) -}}
{%- endif -%}
{# ----- MAIN MESSAGE LOOP ----- #}
{%- for message in messages -%}
{# ----- TOOL MESSAGE -------#}
{%- if message['role'] == 'tool' -%}
{{- render_role_message(message, 'function result') -}}
{# ----- OTHER MESSAGES ----- #}
{%- else -%}
{{- render_role_message(message) -}}
{%- endif -%}
{# ----- ADDING GENERATION PROMPT ----- #}
{%- if loop.last and add_generation_prompt and message['role'] != 'assistant' -%}
{{- 'assistant' + add_tokens.role_sep -}}
{%- endif -%}
{%- endfor -%}
+4
View File
@@ -293,6 +293,10 @@ class LlamaBenchData:
for t in self.repo.tags:
if t.name == name:
return t.commit.hexsha[:self.build_len]
for remote in self.repo.remotes:
for ref in remote.refs:
if ref.name == name or ref.remote_head == name:
return ref.commit.hexsha[:self.build_len]
for c in self.repo.iter_commits("--all"):
if c.hexsha[:self.build_len] == name[:self.build_len]:
return c.hexsha[:self.build_len]
+2 -2
View File
@@ -5,7 +5,7 @@ import os
import sys
import subprocess
HTTPLIB_VERSION = "refs/tags/v0.35.0"
HTTPLIB_VERSION = "refs/tags/v0.37.0"
vendor = {
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
@@ -15,7 +15,7 @@ vendor = {
# not using latest tag to avoid this issue: https://github.com/ggml-org/llama.cpp/pull/17179#discussion_r2515877926
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.24/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://github.com/mackron/miniaudio/raw/13d161bc8d856ad61ae46b798bbeffc0f49808e8/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://github.com/mackron/miniaudio/raw/9634bedb5b5a2ca38c1ee7108a9358a4e233f14d/miniaudio.h": "vendor/miniaudio/miniaudio.h",
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/httplib.h": "httplib.h",
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/split.py": "split.py",
+9
View File
@@ -185,6 +185,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_GROUP_SCALE, "%s.expert_group_scale" },
{ LLM_KV_EXPERTS_PER_GROUP, "%s.experts_per_group" },
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
{ LLM_KV_MOE_LATENT_SIZE, "%s.moe_latent_size" },
{ LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" },
{ LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" },
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
@@ -365,6 +366,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
{ LLM_TENSOR_FFN_LATENT_DOWN, "blk.%d.ffn_latent_down" },
{ LLM_TENSOR_FFN_LATENT_UP, "blk.%d.ffn_latent_up" },
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
@@ -1087,6 +1090,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_CLS_OUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_Q_NORM,
@@ -1878,6 +1882,8 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
LLM_TENSOR_FFN_EXP_PROBS_B,
LLM_TENSOR_FFN_LATENT_DOWN,
LLM_TENSOR_FFN_LATENT_UP,
// MoE shared expert layer
LLM_TENSOR_FFN_DOWN_SHEXP,
LLM_TENSOR_FFN_UP_SHEXP,
@@ -2753,6 +2759,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
// Nemotron 3 Super
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
};
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
+3
View File
@@ -189,6 +189,7 @@ enum llm_kv {
LLM_KV_EXPERT_GROUP_SCALE,
LLM_KV_EXPERTS_PER_GROUP,
LLM_KV_MOE_EVERY_N_LAYERS,
LLM_KV_MOE_LATENT_SIZE,
LLM_KV_NEXTN_PREDICT_LAYERS,
LLM_KV_NUM_DEEPSTACK_LAYERS,
LLM_KV_POOLING_TYPE,
@@ -385,6 +386,8 @@ enum llm_tensor {
LLM_TENSOR_FFN_GATE_CHEXPS,
LLM_TENSOR_FFN_UP_CHEXPS,
LLM_TENSOR_FFN_EXP_PROBS_B,
LLM_TENSOR_FFN_LATENT_DOWN,
LLM_TENSOR_FFN_LATENT_UP,
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_LAYER_OUT_NORM,
+70 -25
View File
@@ -151,7 +151,8 @@ llama_context::llama_context(
cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
cparams.fused_gdn_ar = true;
cparams.fused_gdn_ch = false; // TODO: implement
cparams.fused_gdn_ch = true;
cparams.auto_fgdn = true;
// with causal attention, the batch size is limited by the context size
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
@@ -462,37 +463,81 @@ void llama_context::sched_reserve() {
cparams.auto_fa = false;
}
if (cparams.fused_gdn_ar) {
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check");
}
if (cparams.auto_fgdn) {
LLAMA_LOG_INFO("%s: resolving fused Gated Delta Net support:\n", __func__);
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDNAR) + 1;
bool gdn_device_mismatch = false;
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * n = ggml_graph_node(gf, i);
if (n->op != GGML_OP_GATED_DELTA_NET) {
continue;
if (cparams.fused_gdn_ar) {
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (autoregressive)");
}
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDNAR "-", prefix_len) == 0);
const int il = std::stoi(n->name + prefix_len);
ggml_backend_dev_t device_kv = model.dev_layer(il);
if (device_gdn != device_kv) {
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
"is assigned to device %s (usually due to missing support)\n",
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
gdn_device_mismatch = true;
break;
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_AR) + 1;
bool gdn_device_mismatch = false;
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * n = ggml_graph_node(gf, i);
if (n->op != GGML_OP_GATED_DELTA_NET) {
continue;
}
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_AR "-", prefix_len) == 0);
const int il = std::stoi(n->name + prefix_len);
ggml_backend_dev_t device_kv = model.dev_layer(il);
if (device_gdn != device_kv) {
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
"is assigned to device %s (usually due to missing support)\n",
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
gdn_device_mismatch = true;
break;
}
}
if (gdn_device_mismatch) {
cparams.fused_gdn_ar = false;
LLAMA_LOG_WARN("%s: fused Gated Delta Net (autoregressive) not supported, set to disabled\n", __func__);
} else {
LLAMA_LOG_INFO("%s: fused Gated Delta Net (autoregressive) enabled\n", __func__);
}
}
if (gdn_device_mismatch) {
cparams.fused_gdn_ar = false;
LLAMA_LOG_WARN("%s: fused Gated Delta Net not supported, set to disabled\n", __func__);
if (cparams.fused_gdn_ch) {
// more than one token in the batch per sequence in order to take the chunked path
auto * gf = graph_reserve(16*n_seqs, n_seqs, n_outputs, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)");
}
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_CH) + 1;
bool gdn_device_mismatch = false;
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * n = ggml_graph_node(gf, i);
if (n->op != GGML_OP_GATED_DELTA_NET) {
continue;
}
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_CH "-", prefix_len) == 0);
const int il = std::stoi(n->name + prefix_len);
ggml_backend_dev_t device_kv = model.dev_layer(il);
if (device_gdn != device_kv) {
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
"is assigned to device %s (usually due to missing support)\n",
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
gdn_device_mismatch = true;
break;
}
}
if (gdn_device_mismatch) {
cparams.fused_gdn_ch = false;
LLAMA_LOG_WARN("%s: fused Gated Delta Net (chunked) not supported, set to disabled\n", __func__);
} else {
LLAMA_LOG_INFO("%s: fused Gated Delta Net (chunked) enabled\n", __func__);
}
}
cparams.auto_fgdn = false;
}
// reserve worst-case graph
+1
View File
@@ -33,6 +33,7 @@ struct llama_cparams {
bool auto_fa;
bool fused_gdn_ar; // use fused gated delta net (autoregressive)
bool fused_gdn_ch; // use fused gated delta net (chunked)
bool auto_fgdn;
bool no_perf;
bool warmup;
bool op_offload;
+2 -2
View File
@@ -601,7 +601,7 @@ const char * llama_grammar_parser::parse_sequence(
throw std::runtime_error(std::string("expecting an int at ") + pos);
}
const char * int_end = parse_int(pos);
uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
uint64_t min_times = std::stoull(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
uint64_t max_times = UINT64_MAX; // default: no max limit
@@ -614,7 +614,7 @@ const char * llama_grammar_parser::parse_sequence(
if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
max_times = std::stoul(std::string(pos, int_end - pos));
max_times = std::stoull(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
}
+57 -6
View File
@@ -250,7 +250,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
const bool last = (
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token
);
for (int i = 0; i < n_tokens; ++i) {
@@ -900,7 +900,8 @@ ggml_tensor * llm_graph_context::build_cvec(
ggml_tensor * llm_graph_context::build_lora_mm(
ggml_tensor * w,
ggml_tensor * cur) const {
ggml_tensor * cur,
ggml_tensor * w_s) const {
ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
for (const auto & lora : *loras) {
@@ -921,6 +922,10 @@ ggml_tensor * llm_graph_context::build_lora_mm(
res = ggml_add(ctx0, res, ab_cur);
}
if (w_s) {
res = ggml_mul(ctx0, res, w_s);
}
return res;
}
@@ -1166,7 +1171,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
llama_expert_gating_func_type gating_op,
int il,
ggml_tensor * probs_in,
ggml_tensor * gate_up_exps) const {
ggml_tensor * gate_up_exps,
ggml_tensor * up_exps_s,
ggml_tensor * gate_exps_s,
ggml_tensor * down_exps_s) const {
return build_moe_ffn(
cur,
gate_inp, /* gate_inp_b */ nullptr,
@@ -1182,7 +1190,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
gating_op,
il,
probs_in,
gate_up_exps
gate_up_exps,
/* gate_up_exps_b */ nullptr,
up_exps_s,
gate_exps_s,
down_exps_s
);
}
@@ -1206,7 +1218,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
int il,
ggml_tensor * probs_in,
ggml_tensor * gate_up_exps,
ggml_tensor * gate_up_exps_b) const {
ggml_tensor * gate_up_exps_b,
ggml_tensor * up_exps_s,
ggml_tensor * gate_exps_s,
ggml_tensor * down_exps_s) const {
const int64_t n_embd = cur->ne[0];
const int64_t n_tokens = cur->ne[1];
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
@@ -1358,6 +1373,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(gate_up, "ffn_moe_gate_up_biased", il);
}
// apply per-expert scale2 to merged gate_up (use up_exps_s since gate and up are fused)
if (up_exps_s) {
ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
gate_up = ggml_mul(ctx0, gate_up, s);
cb(gate_up, "ffn_moe_gate_up_scaled", il);
}
const int64_t n_ff = gate_up->ne[0] / 2;
cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0);
cb(cur, "ffn_moe_gate", il);
@@ -1373,6 +1397,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(up, "ffn_moe_up_biased", il);
}
// apply per-expert scale2 to up
if (up_exps_s) {
ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
up = ggml_mul(ctx0, up, s);
cb(up, "ffn_moe_up_scaled", il);
}
if (gate_exps) {
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(cur, "ffn_moe_gate", il);
@@ -1384,6 +1417,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
cb(cur, "ffn_moe_gate_biased", il);
}
// apply per-expert scale2 to gate
if (gate_exps_s) {
ggml_tensor * s = ggml_reshape_3d(ctx0, gate_exps_s, 1, n_expert, 1);
s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
cur = ggml_mul(ctx0, cur, s);
cb(cur, "ffn_moe_gate_scaled", il);
}
}
const bool has_gate = gate_exps || gate_up_exps;
@@ -1463,6 +1505,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(experts, "ffn_moe_down_biased", il);
}
// apply per-expert scale2 to down
if (down_exps_s) {
ggml_tensor * s = ggml_reshape_3d(ctx0, down_exps_s, 1, n_expert, 1);
s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
experts = ggml_mul(ctx0, experts, s);
cb(experts, "ffn_moe_down_scaled", il);
}
if (!weight_before_ffn) {
experts = ggml_mul(ctx0, experts, weights);
cb(cur, "ffn_moe_weighted", il);
@@ -2552,7 +2603,7 @@ void llm_graph_context::build_pooling(
}
// softmax for qwen3 reranker
if (arch == LLM_ARCH_QWEN3) {
if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) {
cur = ggml_soft_max(ctx0, cur);
}
} break;
+11 -4
View File
@@ -764,10 +764,11 @@ struct llm_graph_context {
ggml_tensor * cur,
int il) const;
// do mat_mul, while optionally apply lora
// do mat_mul, while optionally apply lora and per-tensor scale
ggml_tensor * build_lora_mm(
ggml_tensor * w,
ggml_tensor * cur) const;
ggml_tensor * cur,
ggml_tensor * w_s = nullptr) const;
// do mat_mul_id, while optionally apply lora
ggml_tensor * build_lora_mm_id(
@@ -814,7 +815,10 @@ struct llm_graph_context {
llama_expert_gating_func_type gating_op,
int il,
ggml_tensor * probs_in = nullptr,
ggml_tensor * gate_up_exps = nullptr) const;
ggml_tensor * gate_up_exps = nullptr,
ggml_tensor * up_exps_s = nullptr,
ggml_tensor * gate_exps_s = nullptr,
ggml_tensor * down_exps_s = nullptr) const;
ggml_tensor * build_moe_ffn(
ggml_tensor * cur,
@@ -836,7 +840,10 @@ struct llm_graph_context {
int il,
ggml_tensor * probs_in = nullptr,
ggml_tensor * gate_up_exps = nullptr,
ggml_tensor * gate_up_exps_b = nullptr) const;
ggml_tensor * gate_up_exps_b = nullptr,
ggml_tensor * up_exps_s = nullptr,
ggml_tensor * gate_exps_s = nullptr,
ggml_tensor * down_exps_s = nullptr) const;
//
// inputs
+1
View File
@@ -89,6 +89,7 @@ struct llama_hparams {
bool expert_weights_norm = false;
uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
uint32_t moe_every_n_layers = 0;
uint32_t moe_latent_size = 0;
uint32_t nextn_predict_layers = 0;
float f_norm_eps;
+3 -3
View File
@@ -70,6 +70,6 @@ std::string llama_format_tensor_shape(const struct ggml_tensor * t);
std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
#define LLAMA_TENSOR_NAME_FATTN "__fattn__"
#define LLAMA_TENSOR_NAME_FGDNAR "__fgdnar__"
#define LLAMA_TENSOR_NAME_FGDNCH "__fgdnch__"
#define LLAMA_TENSOR_NAME_FATTN "__fattn__"
#define LLAMA_TENSOR_NAME_FGDN_AR "__fgdn_ar__"
#define LLAMA_TENSOR_NAME_FGDN_CH "__fgdn_ch__"
+2
View File
@@ -42,6 +42,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return "MXFP4 MoE";
case LLAMA_FTYPE_MOSTLY_NVFP4: return "NVFP4";
case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small";
@@ -724,6 +725,7 @@ llama_model_loader::llama_model_loader(
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break;
case GGML_TYPE_NVFP4: ftype = LLAMA_FTYPE_MOSTLY_NVFP4; break;
default:
{
LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
+58 -9
View File
@@ -135,6 +135,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_100B_A6B: return "100B.A6B";
case LLM_TYPE_102B_A12B: return "102B.A12B";
case LLM_TYPE_106B_A12B: return "106B.A12B";
case LLM_TYPE_120B_A12B: return "120B.A12B";
case LLM_TYPE_122B_A10B: return "122B.A10B";
case LLM_TYPE_196B_A11B: return "196B.A11B";
case LLM_TYPE_230B_A10B: return "230B.A10B";
@@ -1861,10 +1862,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
ml.get_key(LLM_KV_MOE_LATENT_SIZE, hparams.moe_latent_size, false);
switch (hparams.n_layer) {
case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B
case 56: type = LLM_TYPE_9B; break;
case 88: type = LLM_TYPE_120B_A12B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
@@ -5007,23 +5010,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
layer.wq_scale = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED);
layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wk_scale = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED);
layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv_scale = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED);
layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
layer.wo_scale = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED);
layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_gate_scale = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED);
layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
layer.ffn_down_scale = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED);
layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up_scale = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED);
layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
} break;
case LLM_ARCH_T5:
@@ -5544,6 +5547,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int64_t n_ssm_head = hparams.ssm_dt_rank;
const int64_t n_group = hparams.ssm_n_group;
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
const int64_t moe_n_embd = hparams.moe_latent_size > 0 ? hparams.moe_latent_size : n_embd;
// embeddings
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -5603,8 +5607,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0);
// MoE branch
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_latent_down = create_tensor(tn(LLM_TENSOR_FFN_LATENT_DOWN, "weight", i), {n_embd, moe_n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_latent_up = create_tensor(tn(LLM_TENSOR_FFN_LATENT_UP, "weight", i), {moe_n_embd, n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, moe_n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {moe_n_embd, n_ff_exp, n_expert}, 0);
// Shared expert branch
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
@@ -7436,6 +7443,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
default:
throw std::runtime_error("unknown architecture");
}
// generic pass: load optional per-tensor/per-expert ".scale" tensors (e.g. NVFP4 scale2)
// this avoids having to add scale loading to every architecture
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
// attention weight scales (per-tensor, shape {1})
if (!layer.wq_s && layer.wq) {
layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
if (!layer.wk_s && layer.wk) {
layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
if (!layer.wv_s && layer.wv) {
layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
if (!layer.wo_s && layer.wo) {
layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
// dense FFN weight scales (per-tensor, shape {1})
if (!layer.ffn_gate_s && layer.ffn_gate) {
layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
if (!layer.ffn_down_s && layer.ffn_down) {
layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
if (!layer.ffn_up_s && layer.ffn_up) {
layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED);
}
// MoE expert weight scales (per-expert, shape {n_expert})
if (!layer.ffn_gate_exps_s && layer.ffn_gate_exps) {
layer.ffn_gate_exps_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED);
}
if (!layer.ffn_down_exps_s && layer.ffn_down_exps) {
layer.ffn_down_exps_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED);
}
if (!layer.ffn_up_exps_s && layer.ffn_up_exps) {
layer.ffn_up_exps_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED);
}
}
}
ml.done_getting_tensors();
+17 -7
View File
@@ -126,6 +126,7 @@ enum llm_type {
LLM_TYPE_100B_A6B,
LLM_TYPE_102B_A12B, // Solar-Open
LLM_TYPE_106B_A12B, // GLM-4.5-Air
LLM_TYPE_120B_A12B, // Nemotron 3 Super
LLM_TYPE_122B_A10B, // Qwen3.5
LLM_TYPE_196B_A11B, // Step3.5-Flash
LLM_TYPE_230B_A10B, // Minimax M2
@@ -294,6 +295,15 @@ struct llama_layer {
struct ggml_tensor * ffn_up_exps_b = nullptr;
struct ggml_tensor * ffn_gate_up_exps_b = nullptr;
// ff MoE per-expert scales (NVFP4 per-tensor scale2)
struct ggml_tensor * ffn_gate_exps_s = nullptr;
struct ggml_tensor * ffn_down_exps_s = nullptr;
struct ggml_tensor * ffn_up_exps_s = nullptr;
// ff MoE latent proj
struct ggml_tensor * ffn_latent_down = nullptr;
struct ggml_tensor * ffn_latent_up = nullptr;
// ff shared expert (shexp)
struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
struct ggml_tensor * ffn_gate_shexp = nullptr;
@@ -387,13 +397,13 @@ struct llama_layer {
struct ggml_tensor * rope_freqs = nullptr;
// bitnet scale
struct ggml_tensor * wq_scale = nullptr;
struct ggml_tensor * wk_scale = nullptr;
struct ggml_tensor * wv_scale = nullptr;
struct ggml_tensor * wo_scale = nullptr;
struct ggml_tensor * ffn_gate_scale = nullptr;
struct ggml_tensor * ffn_up_scale = nullptr;
struct ggml_tensor * ffn_down_scale = nullptr;
struct ggml_tensor * wq_s = nullptr;
struct ggml_tensor * wk_s = nullptr;
struct ggml_tensor * wv_s = nullptr;
struct ggml_tensor * wo_s = nullptr;
struct ggml_tensor * ffn_gate_s = nullptr;
struct ggml_tensor * ffn_up_s = nullptr;
struct ggml_tensor * ffn_down_s = nullptr;
// altup & laurel
struct ggml_tensor * per_layer_inp_gate = nullptr;
+16 -13
View File
@@ -870,9 +870,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
quantize_state_impl qs(model, params);
// these need to be set to n_layer by default
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
if (params->only_copy) {
ftype = ml.ftype;
}
@@ -979,6 +976,22 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// compute tensor metadata once and cache it
std::vector<tensor_metadata> metadata(tensors.size());
// initialize quantization state before preliminary loop (counters for use_more_bits)
{
for (size_t i = 0; i < tensors.size(); ++i) {
const auto cat = tensor_get_category(tensors[i]->tensor->name);
if (category_is_attn_v(cat)) {
++qs.n_attention_wv;
}
if (cat == tensor_category::OUTPUT) {
qs.has_tied_embeddings = false;
}
metadata[i].category = cat; // save and re-use the category while we're at it
}
// these also need to be set to n_layer by default
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer;
}
// flag for --dry-run
bool will_require_imatrix = false;
@@ -991,16 +1004,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
const struct ggml_tensor * tensor = it->tensor;
const std::string name = ggml_get_name(tensor);
metadata[i].category = tensor_get_category(name);
if (category_is_attn_v(metadata[i].category)) {
++qs.n_attention_wv;
}
if (tensor_name_match_output_weight(name.c_str())) {
qs.has_tied_embeddings = false;
}
uint16_t i_split = params->keep_split ? it->idx : 0;
if (!ctx_outs[i_split]) {
ctx_outs[i_split].reset(gguf_init_empty());
+7 -22
View File
@@ -29,10 +29,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
// self-attention
{
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
if (model.layers[il].wq_scale) {
Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
}
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
@@ -40,10 +37,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
}
// B1.K
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
if (model.layers[il].wk_scale) {
Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
}
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
@@ -51,10 +45,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
}
// B1.V
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
if (model.layers[il].wv_scale) {
Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
}
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -90,10 +81,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
LLM_NORM_RMS, il);
cb(cur, "attn_sub_norm", il);
cur = build_lora_mm(model.layers[il].wo, cur);
if (model.layers[il].wo_scale) {
cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
}
cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s);
if (model.layers[il].bo) {
cur = ggml_add(ctx0, cur, model.layers[il].bo);
}
@@ -115,8 +103,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale,
model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale,
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s,
model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s,
NULL, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
@@ -127,10 +115,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
LLM_NORM_RMS, il);
cb(cur, "ffn_sub_norm", il);
cur = build_lora_mm(model.layers[il].ffn_down, cur);
if (model.layers[il].ffn_down_scale) {
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
}
cur = build_lora_mm(model.layers[il].ffn_down, cur, model.layers[il].ffn_down_s);
cb(cur, "ffn_down", il);
cur = ggml_add(ctx0, cur, ffn_inp);
+75 -27
View File
@@ -41,13 +41,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
if (cparams.fused_gdn_ch) {
//ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
//cb(result, LLAMA_TENSOR_NAME_FGDNCH, il);
GGML_ABORT("not implemented yet");
}
const float scale = 1.0f / sqrtf(S_k);
q = ggml_scale(ctx0, q, scale);
@@ -325,26 +318,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
if (cparams.fused_gdn_ar) {
ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
cb(result, LLAMA_TENSOR_NAME_FGDNAR, il);
ggml_tensor * output = ggml_view_4d(ctx0, result,
S_v, H_v, n_tokens, n_seqs,
ggml_row_size(result->type, S_v),
ggml_row_size(result->type, S_v * H_v),
ggml_row_size(result->type, S_v * H_v * n_tokens), 0);
ggml_tensor * new_state = ggml_view_4d(ctx0, result,
S_v, S_v, H_v, n_seqs,
ggml_row_size(result->type, S_v),
ggml_row_size(result->type, S_v * S_v),
ggml_row_size(result->type, S_v * S_v * H_v),
ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs));
return {output, new_state};
}
const float scale = 1.0f / sqrtf(S_k);
q = ggml_scale(ctx0, q, scale);
@@ -401,3 +374,78 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
return {o, s};
}
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_fused(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * b,
ggml_tensor * s,
int il) {
const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];
GGML_ASSERT(S_k == S_v);
GGML_ASSERT(H_v % H_k == 0);
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
if (n_tokens == 1) {
cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il);
} else {
cb(result, LLAMA_TENSOR_NAME_FGDN_CH, il);
}
ggml_tensor * output = ggml_view_4d(ctx0, result,
S_v, H_v, n_tokens, n_seqs,
ggml_row_size(result->type, S_v),
ggml_row_size(result->type, S_v * H_v),
ggml_row_size(result->type, S_v * H_v * n_tokens), 0);
ggml_tensor * new_state = ggml_view_4d(ctx0, result,
S_v, S_v, H_v, n_seqs,
ggml_row_size(result->type, S_v),
ggml_row_size(result->type, S_v * S_v),
ggml_row_size(result->type, S_v * S_v * H_v),
ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs));
return {output, new_state};
}
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * b,
ggml_tensor * s,
int il) {
const int64_t n_seq_tokens = q->ne[2];
if (n_seq_tokens == 1) {
if (cparams.fused_gdn_ar) {
return build_delta_net_fused(q, k, v, g, b, s, il);
}
return build_delta_net_autoregressive(q, k, v, g, b, s, il);
}
if (cparams.fused_gdn_ch) {
return build_delta_net_fused(q, k, v, g, b, s, il);
}
return build_delta_net_chunking(q, k, v, g, b, s, il);
}
+1 -3
View File
@@ -169,9 +169,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm);
// Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens
std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
build_delta_net_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
build_delta_net_chunking(Qcur, Kcur, Vcur, g1, beta, state, il);
auto attn_out = build_delta_net(Qcur, Kcur, Vcur, g1, beta, state, il);
ggml_tensor * output = ggml_cont(ctx0, attn_out.first);
ggml_tensor * new_state = attn_out.second;
+14 -7
View File
@@ -43,19 +43,19 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -91,6 +91,9 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra
cur = build_attn(inp_attn,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
if (model.layers[il].wo_s) {
cur = ggml_mul(ctx0, cur, model.layers[il].wo_s);
}
cb(cur, "attn_out", il);
}
if (il == n_layer - 1 && inp_out_ids) {
@@ -109,9 +112,9 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, model.layers[il].ffn_gate_s,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
@@ -132,7 +135,11 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra
LLM_FFN_SILU, true,
hparams.expert_weights_scale,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
il);
il,
nullptr, nullptr,
model.layers[il].ffn_up_exps_s,
model.layers[il].ffn_gate_exps_s,
model.layers[il].ffn_down_exps_s);
cb(cur, "ffn_moe_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
+3 -2
View File
@@ -168,8 +168,9 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
GGML_ASSERT(n_seqs != 0);
GGML_ASSERT(ubatch.equal_seqs());
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
GGML_ASSERT(d_inner % n_head == 0);
GGML_ASSERT(d_inner % (n_group*d_state) == 0);
GGML_ASSERT(d_inner % n_head == 0);
GGML_ASSERT(d_inner % d_state == 0);
GGML_ASSERT(d_inner % n_group == 0);
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
+20
View File
@@ -44,6 +44,26 @@ struct llm_build_delta_net_base : public llm_graph_context {
ggml_tensor * b,
ggml_tensor * s,
int il);
// use the ggml_gated_delta_net fused operator
std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_fused(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * b,
ggml_tensor * s,
int il);
// choose one of two implementations above based on the number of tokens
std::pair<ggml_tensor *, ggml_tensor *> build_delta_net(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * b,
ggml_tensor * s,
int il);
};
struct llm_build_rwkv6_base : public llm_graph_context {

Some files were not shown because too many files have changed in this diff Show More