Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1257491047 | |||
| 083e18b11c | |||
| 3d94e967a1 | |||
| 7feb0a1005 | |||
| 0a8026e768 | |||
| 5ceed62421 | |||
| 7ca5991d2b | |||
| b3e3060f4e | |||
| 37adc9c6ba |
@@ -547,6 +547,46 @@ jobs:
|
||||
# This is using llvmpipe and runs slower than other backends
|
||||
ctest -L main --verbose --timeout 3600
|
||||
|
||||
ubuntu-24-wasm-webgpu:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-latest-wasm-webgpu
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Install Emscripten
|
||||
run: |
|
||||
git clone https://github.com/emscripten-core/emsdk.git
|
||||
cd emsdk
|
||||
./emsdk install latest
|
||||
./emsdk activate latest
|
||||
|
||||
- name: Fetch emdawnwebgpu
|
||||
run: |
|
||||
DAWN_TAG="v20251027.212519"
|
||||
EMDAWN_PKG="emdawnwebgpu_pkg-${DAWN_TAG}.zip"
|
||||
echo "Downloading ${EMDAWN_PKG}"
|
||||
curl -L -o emdawn.zip \
|
||||
"https://github.com/google/dawn/releases/download/${DAWN_TAG}/${EMDAWN_PKG}"
|
||||
unzip emdawn.zip
|
||||
|
||||
- name: Build WASM WebGPU
|
||||
run: |
|
||||
source emsdk/emsdk_env.sh
|
||||
emcmake cmake -B build-wasm \
|
||||
-DGGML_WEBGPU=ON \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DEMDAWNWEBGPU_DIR=emdawnwebgpu_pkg
|
||||
|
||||
cmake --build build-wasm --target test-backend-ops -j $(nproc)
|
||||
|
||||
ubuntu-22-cmake-hip:
|
||||
runs-on: ubuntu-22.04
|
||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
||||
|
||||
@@ -728,58 +728,6 @@ jobs:
|
||||
path: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz
|
||||
name: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz
|
||||
|
||||
openEuler-cann:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: [x86, aarch64]
|
||||
chip_type: ['910b', '310p']
|
||||
build: ['Release']
|
||||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
yum update -y
|
||||
yum install -y git gcc gcc-c++ make cmake libcurl-devel
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH}
|
||||
|
||||
cmake -S . -B build \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
|
||||
-DGGML_CANN=on \
|
||||
-DSOC_TYPE=ascend${{ matrix.chip_type }}
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
- name: Determine tag name
|
||||
id: tag
|
||||
uses: ./.github/actions/get-tag-name
|
||||
|
||||
- name: Pack artifacts
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
zip -y -r llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip ./build/bin/*
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts (zip)
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
|
||||
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
|
||||
|
||||
- name: Upload artifacts (tar)
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz
|
||||
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz
|
||||
|
||||
release:
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
|
||||
@@ -801,7 +749,6 @@ jobs:
|
||||
- macOS-arm64
|
||||
- macOS-x64
|
||||
- ios-xcode-build
|
||||
- openEuler-cann
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -869,6 +816,12 @@ jobs:
|
||||
> [!WARNING]
|
||||
> **Release Format Update**: Linux releases will soon use .tar.gz archives instead of .zip. Please make the necessary changes to your deployment scripts.
|
||||
|
||||
<details open>
|
||||
|
||||
${{ github.event.head_commit.message }}
|
||||
|
||||
</details>
|
||||
|
||||
**macOS/iOS:**
|
||||
- [macOS Apple Silicon (arm64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz)
|
||||
- [macOS Intel (x64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz)
|
||||
@@ -887,18 +840,6 @@ jobs:
|
||||
- [Windows x64 (SYCL)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip)
|
||||
- [Windows x64 (HIP)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-hip-radeon-x64.zip)
|
||||
|
||||
**openEuler:**
|
||||
- [openEuler x86 (310p)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-310p-openEuler-x86.tar.gz)
|
||||
- [openEuler x86 (910b)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-910b-openEuler-x86.tar.gz)
|
||||
- [openEuler aarch64 (310p)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-310p-openEuler-aarch64.tar.gz)
|
||||
- [openEuler aarch64 (910b)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-910b-openEuler-aarch64.tar.gz)
|
||||
|
||||
<details>
|
||||
|
||||
${{ github.event.head_commit.message }}
|
||||
|
||||
</details>
|
||||
|
||||
- name: Upload release
|
||||
id: upload_release
|
||||
uses: actions/github-script@v3
|
||||
|
||||
@@ -134,3 +134,5 @@ poetry.toml
|
||||
# IDE
|
||||
/*.code-workspace
|
||||
/.windsurf/
|
||||
# emscripten
|
||||
a.out.*
|
||||
|
||||
+15
-1
@@ -33,10 +33,24 @@ endif()
|
||||
|
||||
option(LLAMA_USE_SYSTEM_GGML "Use system libggml" OFF)
|
||||
|
||||
option(LLAMA_WASM_MEM64 "llama: use 64-bit memory in WASM builds" ON)
|
||||
|
||||
if (EMSCRIPTEN)
|
||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
||||
|
||||
option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON)
|
||||
# Use 64-bit memory to support backend_get_memory queries
|
||||
# TODO: analyze performance impact, see https://spidermonkey.dev/blog/2025/01/15/is-memory64-actually-worth-using
|
||||
if (LLAMA_WASM_MEM64)
|
||||
add_compile_options("-sMEMORY64=1")
|
||||
add_link_options("-sMEMORY64=1")
|
||||
endif()
|
||||
add_link_options("-sALLOW_MEMORY_GROWTH=1")
|
||||
|
||||
option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" OFF)
|
||||
option(LLAMA_BUILD_HTML "llama: build HTML file" ON)
|
||||
if (LLAMA_BUILD_HTML)
|
||||
set(CMAKE_EXECUTABLE_SUFFIX ".html")
|
||||
endif()
|
||||
else()
|
||||
if (MINGW)
|
||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
||||
|
||||
@@ -10,13 +10,16 @@
|
||||
/common/arg.* @ggerganov
|
||||
/common/base64.hpp.* @ggerganov
|
||||
/common/build-info.* @ggerganov
|
||||
/common/chat-peg-parser.* @aldehir
|
||||
/common/common.* @ggerganov
|
||||
/common/console.* @ggerganov
|
||||
/common/http.* @angt
|
||||
/common/llguidance.* @ggerganov
|
||||
/common/log.* @ggerganov
|
||||
/common/peg-parser.* @aldehir
|
||||
/common/sampling.* @ggerganov
|
||||
/common/speculative.* @ggerganov
|
||||
/common/unicode.* @aldehir
|
||||
/convert_*.py @CISC
|
||||
/examples/batched.swift/ @ggerganov
|
||||
/examples/batched/ @ggerganov
|
||||
|
||||
@@ -52,6 +52,8 @@ add_library(${TARGET} STATIC
|
||||
chat-parser.h
|
||||
chat-parser-xml-toolcall.h
|
||||
chat-parser-xml-toolcall.cpp
|
||||
chat-peg-parser.cpp
|
||||
chat-peg-parser.h
|
||||
chat.cpp
|
||||
chat.h
|
||||
common.cpp
|
||||
@@ -69,12 +71,16 @@ add_library(${TARGET} STATIC
|
||||
log.h
|
||||
ngram-cache.cpp
|
||||
ngram-cache.h
|
||||
peg-parser.cpp
|
||||
peg-parser.h
|
||||
regex-partial.cpp
|
||||
regex-partial.h
|
||||
sampling.cpp
|
||||
sampling.h
|
||||
speculative.cpp
|
||||
speculative.h
|
||||
unicode.cpp
|
||||
unicode.h
|
||||
)
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
|
||||
@@ -30,6 +30,7 @@
|
||||
#include <thread> // for hardware_concurrency
|
||||
#include <vector>
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
#ifdef __linux__
|
||||
#include <linux/limits.h>
|
||||
#elif defined(_WIN32)
|
||||
@@ -41,6 +42,8 @@
|
||||
#else
|
||||
#include <sys/syslimits.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#include "chat-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "peg-parser.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -1483,6 +1485,11 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
|
||||
if (syntax.format == COMMON_CHAT_FORMAT_PEG_SIMPLE ||
|
||||
syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE ||
|
||||
syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) {
|
||||
return common_chat_peg_parse(syntax.parser, input, is_partial, syntax);
|
||||
}
|
||||
common_chat_msg_parser builder(input, is_partial, syntax);
|
||||
try {
|
||||
common_chat_parse(builder);
|
||||
@@ -1500,3 +1507,36 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
// Parses `input` with the supplied PEG parser arena and converts the resulting
// AST into an assistant chat message.
//
//   parser     - parser arena to run; must not be empty
//   input      - raw model output to parse
//   is_partial - forwarded to the parse context; also suppresses the final
//                "Parsed message" debug log when true
//   syntax     - only `syntax.format` is consulted here, to pick the AST mapper
//
// Throws std::runtime_error when the arena is empty or the parse result
// reports failure.
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
    if (parser.empty()) {
        throw std::runtime_error("Failed to parse due to missing parser definition.");
    }

    LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(syntax.format), input.c_str());

    common_peg_parse_context parse_ctx(input, is_partial);
    auto parsed = parser.parse(parse_ctx);
    if (parsed.fail()) {
        throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(parsed.end));
    }

    common_chat_msg msg;
    msg.role = "assistant";

    // Each PEG format has its own mapper; anything else uses the generic one.
    switch (syntax.format) {
        case COMMON_CHAT_FORMAT_PEG_NATIVE: {
            common_chat_peg_native_mapper native_mapper(msg);
            native_mapper.from_ast(parse_ctx.ast, parsed);
            break;
        }
        case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: {
            common_chat_peg_constructed_mapper constructed_mapper(msg);
            constructed_mapper.from_ast(parse_ctx.ast, parsed);
            break;
        }
        default: {
            common_chat_peg_mapper generic_mapper(msg);
            generic_mapper.from_ast(parse_ctx.ast, parsed);
            break;
        }
    }

    if (!is_partial) {
        LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
    }
    return msg;
}
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
#include "chat-peg-parser.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
// Returns a view of `sv` with every trailing character for which std::isspace
// is true removed. The underlying characters are not copied or modified.
static std::string_view trim_trailing_space(std::string_view sv) {
    size_t keep = sv.size();
    // Cast through unsigned char: std::isspace is undefined for negative values.
    while (keep > 0 && std::isspace(static_cast<unsigned char>(sv[keep - 1]))) {
        --keep;
    }
    return sv.substr(0, keep);
}
|
||||
|
||||
void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) {
|
||||
arena.visit(result, [this](const common_peg_ast_node & node) {
|
||||
map(node);
|
||||
});
|
||||
}
|
||||
|
||||
void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
||||
bool is_reasoning = node.tag == common_chat_peg_builder::REASONING;
|
||||
bool is_content = node.tag == common_chat_peg_builder::CONTENT;
|
||||
|
||||
if (is_reasoning) {
|
||||
result.reasoning_content = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
|
||||
if (is_content) {
|
||||
result.content = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
}
|
||||
|
||||
void common_chat_peg_native_mapper::map(const common_peg_ast_node & node) {
|
||||
common_chat_peg_mapper::map(node);
|
||||
|
||||
bool is_tool_open = node.tag == common_chat_peg_native_builder::TOOL_OPEN;
|
||||
bool is_tool_name = node.tag == common_chat_peg_native_builder::TOOL_NAME;
|
||||
bool is_tool_id = node.tag == common_chat_peg_native_builder::TOOL_ID;
|
||||
bool is_tool_args = node.tag == common_chat_peg_native_builder::TOOL_ARGS;
|
||||
|
||||
if (is_tool_open) {
|
||||
result.tool_calls.emplace_back();
|
||||
current_tool = &result.tool_calls.back();
|
||||
}
|
||||
|
||||
if (is_tool_id && current_tool) {
|
||||
current_tool->id = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
|
||||
if (is_tool_name && current_tool) {
|
||||
current_tool->name = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
|
||||
if (is_tool_args && current_tool) {
|
||||
current_tool->arguments = std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
}
|
||||
|
||||
void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
|
||||
common_chat_peg_mapper::map(node);
|
||||
|
||||
bool is_tool_open = node.tag == common_chat_peg_constructed_builder::TOOL_OPEN;
|
||||
bool is_tool_name = node.tag == common_chat_peg_constructed_builder::TOOL_NAME;
|
||||
bool is_tool_close = node.tag == common_chat_peg_constructed_builder::TOOL_CLOSE;
|
||||
bool is_arg_open = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_OPEN;
|
||||
bool is_arg_close = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_CLOSE;
|
||||
bool is_arg_name = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_NAME;
|
||||
bool is_arg_string = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_STRING_VALUE;
|
||||
bool is_arg_json = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_JSON_VALUE;
|
||||
|
||||
if (is_tool_open) {
|
||||
result.tool_calls.emplace_back();
|
||||
current_tool = &result.tool_calls.back();
|
||||
arg_count = 0;
|
||||
}
|
||||
|
||||
if (is_tool_name) {
|
||||
current_tool->name = std::string(node.text);
|
||||
current_tool->arguments = "{";
|
||||
}
|
||||
|
||||
if (is_arg_open) {
|
||||
needs_closing_quote = false;
|
||||
}
|
||||
|
||||
if (is_arg_name && current_tool) {
|
||||
if (arg_count > 0) {
|
||||
current_tool->arguments += ",";
|
||||
}
|
||||
current_tool->arguments += json(trim_trailing_space(node.text)).dump() + ":";
|
||||
++arg_count;
|
||||
}
|
||||
|
||||
if (is_arg_string && current_tool) {
|
||||
// Serialize to JSON, but exclude the end quote
|
||||
std::string dumped = json(node.text).dump();
|
||||
current_tool->arguments += dumped.substr(0, dumped.size() - 1);
|
||||
needs_closing_quote = true;
|
||||
}
|
||||
|
||||
if (is_arg_close && current_tool) {
|
||||
if (needs_closing_quote) {
|
||||
current_tool->arguments += "\"";
|
||||
}
|
||||
}
|
||||
|
||||
if (is_arg_json && current_tool) {
|
||||
current_tool->arguments += std::string(trim_trailing_space(node.text));
|
||||
}
|
||||
|
||||
if (is_tool_close && current_tool) {
|
||||
current_tool->arguments += "}";
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
#include "peg-parser.h"
|
||||
|
||||
class common_chat_peg_builder : public common_peg_parser_builder {
|
||||
public:
|
||||
static constexpr const char * REASONING_BLOCK = "reasoning-block";
|
||||
static constexpr const char * REASONING = "reasoning";
|
||||
static constexpr const char * CONTENT = "content";
|
||||
|
||||
common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); }
|
||||
common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); }
|
||||
common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); }
|
||||
};
|
||||
|
||||
inline common_peg_arena build_chat_peg_parser(const std::function<common_peg_parser(common_chat_peg_builder & builder)> & fn) {
|
||||
common_chat_peg_builder builder;
|
||||
builder.set_root(fn(builder));
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
class common_chat_peg_mapper {
|
||||
public:
|
||||
common_chat_msg & result;
|
||||
|
||||
common_chat_peg_mapper(common_chat_msg & msg) : result(msg) {}
|
||||
|
||||
virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
|
||||
virtual void map(const common_peg_ast_node & node);
|
||||
};
|
||||
|
||||
class common_chat_peg_native_builder : public common_chat_peg_builder {
|
||||
public:
|
||||
static constexpr const char * TOOL = "tool";
|
||||
static constexpr const char * TOOL_OPEN = "tool-open";
|
||||
static constexpr const char * TOOL_CLOSE = "tool-close";
|
||||
static constexpr const char * TOOL_ID = "tool-id";
|
||||
static constexpr const char * TOOL_NAME = "tool-name";
|
||||
static constexpr const char * TOOL_ARGS = "tool-args";
|
||||
|
||||
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
|
||||
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
|
||||
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
|
||||
common_peg_parser tool_id(const common_peg_parser & p) { return atomic(tag(TOOL_ID, p)); }
|
||||
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
|
||||
common_peg_parser tool_args(const common_peg_parser & p) { return tag(TOOL_ARGS, p); }
|
||||
};
|
||||
|
||||
class common_chat_peg_native_mapper : public common_chat_peg_mapper {
|
||||
common_chat_tool_call * current_tool;
|
||||
|
||||
public:
|
||||
common_chat_peg_native_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
|
||||
|
||||
void map(const common_peg_ast_node & node) override;
|
||||
};
|
||||
|
||||
inline common_peg_arena build_chat_peg_native_parser(const std::function<common_peg_parser(common_chat_peg_native_builder & builder)> & fn) {
|
||||
common_chat_peg_native_builder builder;
|
||||
builder.set_root(fn(builder));
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
class common_chat_peg_constructed_builder : public common_chat_peg_builder {
|
||||
public:
|
||||
static constexpr const char * TOOL = "tool";
|
||||
static constexpr const char * TOOL_OPEN = "tool-open";
|
||||
static constexpr const char * TOOL_CLOSE = "tool-close";
|
||||
static constexpr const char * TOOL_NAME = "tool-name";
|
||||
static constexpr const char * TOOL_ARG = "tool-arg";
|
||||
static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open";
|
||||
static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close";
|
||||
static constexpr const char * TOOL_ARG_NAME = "tool-arg-name";
|
||||
static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value";
|
||||
static constexpr const char * TOOL_ARG_JSON_VALUE = "tool-arg-json-value";
|
||||
|
||||
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
|
||||
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
|
||||
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
|
||||
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
|
||||
common_peg_parser tool_arg(const common_peg_parser & p) { return tag(TOOL_ARG, p); }
|
||||
common_peg_parser tool_arg_open(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_OPEN, p)); }
|
||||
common_peg_parser tool_arg_close(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_CLOSE, p)); }
|
||||
common_peg_parser tool_arg_name(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_NAME, p)); }
|
||||
common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); }
|
||||
common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return tag(TOOL_ARG_JSON_VALUE, p); }
|
||||
};
|
||||
|
||||
class common_chat_peg_constructed_mapper : public common_chat_peg_mapper {
|
||||
common_chat_tool_call * current_tool;
|
||||
int arg_count = 0;
|
||||
bool needs_closing_quote = false;
|
||||
|
||||
public:
|
||||
common_chat_peg_constructed_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
|
||||
|
||||
void map(const common_peg_ast_node & node) override;
|
||||
};
|
||||
|
||||
inline common_peg_arena build_chat_peg_constructed_parser(const std::function<common_peg_parser(common_chat_peg_constructed_builder & builder)> & fn) {
|
||||
common_chat_peg_constructed_builder builder;
|
||||
builder.set_root(fn(builder));
|
||||
return builder.build();
|
||||
}
|
||||
@@ -649,6 +649,9 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
|
||||
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
|
||||
case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple";
|
||||
case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native";
|
||||
case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "peg-parser.h"
|
||||
#include <functional>
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
@@ -124,6 +125,11 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_APRIEL_1_5,
|
||||
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
|
||||
|
||||
// These are intended to be parsed by the PEG parser
|
||||
COMMON_CHAT_FORMAT_PEG_SIMPLE,
|
||||
COMMON_CHAT_FORMAT_PEG_NATIVE,
|
||||
COMMON_CHAT_FORMAT_PEG_CONSTRUCTED,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
||||
@@ -154,6 +160,7 @@ struct common_chat_params {
|
||||
std::vector<common_grammar_trigger> grammar_triggers;
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
std::string parser;
|
||||
};
|
||||
|
||||
struct common_chat_syntax {
|
||||
@@ -163,6 +170,7 @@ struct common_chat_syntax {
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
bool parse_tool_calls = true;
|
||||
common_peg_arena parser = {};
|
||||
};
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
@@ -206,6 +214,7 @@ const char* common_chat_format_name(common_chat_format format);
|
||||
const char* common_reasoning_format_name(common_reasoning_format format);
|
||||
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
|
||||
|
||||
|
||||
@@ -902,6 +902,8 @@ std::string fs_get_cache_directory() {
|
||||
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
|
||||
#elif defined(_WIN32)
|
||||
cache_directory = std::getenv("LOCALAPPDATA");
|
||||
#elif defined(__EMSCRIPTEN__)
|
||||
GGML_ABORT("not implemented on this platform");
|
||||
#else
|
||||
# error Unknown architecture
|
||||
#endif
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
#include "http.h"
|
||||
#endif
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
#ifdef __linux__
|
||||
#include <linux/limits.h>
|
||||
#elif defined(_WIN32)
|
||||
@@ -35,6 +36,8 @@
|
||||
#else
|
||||
#include <sys/syslimits.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
|
||||
|
||||
// isatty
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,459 @@
|
||||
#pragma once
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include <variant>
|
||||
|
||||
struct common_grammar_builder;
|
||||
|
||||
class common_peg_parser_builder;
|
||||
|
||||
using common_peg_parser_id = size_t;
|
||||
constexpr common_peg_parser_id COMMON_PEG_INVALID_PARSER_ID = static_cast<common_peg_parser_id>(-1);
|
||||
|
||||
using common_peg_ast_id = size_t;
|
||||
constexpr common_peg_ast_id COMMON_PEG_INVALID_AST_ID = static_cast<common_peg_ast_id>(-1);
|
||||
|
||||
// Lightweight wrapper around common_peg_parser_id for convenience
|
||||
class common_peg_parser {
|
||||
common_peg_parser_id id_;
|
||||
common_peg_parser_builder & builder_;
|
||||
|
||||
public:
|
||||
common_peg_parser(const common_peg_parser & other) : id_(other.id_), builder_(other.builder_) {}
|
||||
common_peg_parser(common_peg_parser_id id, common_peg_parser_builder & builder) : id_(id), builder_(builder) {}
|
||||
|
||||
common_peg_parser & operator=(const common_peg_parser & other);
|
||||
common_peg_parser & operator+=(const common_peg_parser & other);
|
||||
common_peg_parser & operator|=(const common_peg_parser & other);
|
||||
|
||||
operator common_peg_parser_id() const { return id_; }
|
||||
common_peg_parser_id id() const { return id_; }
|
||||
|
||||
common_peg_parser_builder & builder() const { return builder_; }
|
||||
|
||||
// Creates a sequence
|
||||
common_peg_parser operator+(const common_peg_parser & other) const;
|
||||
|
||||
// Creates a sequence separated by spaces.
|
||||
common_peg_parser operator<<(const common_peg_parser & other) const;
|
||||
|
||||
// Creates a choice
|
||||
common_peg_parser operator|(const common_peg_parser & other) const;
|
||||
|
||||
common_peg_parser operator+(const char * str) const;
|
||||
common_peg_parser operator+(const std::string & str) const;
|
||||
common_peg_parser operator<<(const char * str) const;
|
||||
common_peg_parser operator<<(const std::string & str) const;
|
||||
common_peg_parser operator|(const char * str) const;
|
||||
common_peg_parser operator|(const std::string & str) const;
|
||||
};
|
||||
|
||||
common_peg_parser operator+(const char * str, const common_peg_parser & p);
|
||||
common_peg_parser operator+(const std::string & str, const common_peg_parser & p);
|
||||
common_peg_parser operator<<(const char * str, const common_peg_parser & p);
|
||||
common_peg_parser operator<<(const std::string & str, const common_peg_parser & p);
|
||||
common_peg_parser operator|(const char * str, const common_peg_parser & p);
|
||||
common_peg_parser operator|(const std::string & str, const common_peg_parser & p);
|
||||
|
||||
enum common_peg_parse_result_type {
|
||||
COMMON_PEG_PARSE_RESULT_FAIL = 0,
|
||||
COMMON_PEG_PARSE_RESULT_SUCCESS = 1,
|
||||
COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT = 2,
|
||||
};
|
||||
|
||||
const char * common_peg_parse_result_type_name(common_peg_parse_result_type type);
|
||||
|
||||
struct common_peg_ast_node {
|
||||
common_peg_ast_id id;
|
||||
std::string rule;
|
||||
std::string tag;
|
||||
size_t start;
|
||||
size_t end;
|
||||
std::string_view text;
|
||||
std::vector<common_peg_ast_id> children;
|
||||
|
||||
bool is_partial = false;
|
||||
};
|
||||
|
||||
struct common_peg_parse_result;
|
||||
|
||||
using common_peg_ast_visitor = std::function<void(const common_peg_ast_node & node)>;
|
||||
|
||||
class common_peg_ast_arena {
|
||||
std::vector<common_peg_ast_node> nodes_;
|
||||
public:
|
||||
common_peg_ast_id add_node(
|
||||
const std::string & rule,
|
||||
const std::string & tag,
|
||||
size_t start,
|
||||
size_t end,
|
||||
std::string_view text,
|
||||
std::vector<common_peg_ast_id> children,
|
||||
bool is_partial = false
|
||||
) {
|
||||
common_peg_ast_id id = nodes_.size();
|
||||
nodes_.push_back({id, rule, tag, start, end, text, std::move(children), is_partial});
|
||||
return id;
|
||||
}
|
||||
|
||||
const common_peg_ast_node & get(common_peg_ast_id id) const { return nodes_.at(id); }
|
||||
|
||||
size_t size() const { return nodes_.size(); }
|
||||
|
||||
void clear() { nodes_.clear(); }
|
||||
|
||||
void visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const;
|
||||
void visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const;
|
||||
};
|
||||
|
||||
struct common_peg_parse_result {
|
||||
common_peg_parse_result_type type = COMMON_PEG_PARSE_RESULT_FAIL;
|
||||
size_t start = 0;
|
||||
size_t end = 0;
|
||||
|
||||
std::vector<common_peg_ast_id> nodes;
|
||||
|
||||
common_peg_parse_result() = default;
|
||||
|
||||
common_peg_parse_result(common_peg_parse_result_type type, size_t start)
|
||||
: type(type), start(start), end(start) {}
|
||||
|
||||
common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end)
|
||||
: type(type), start(start), end(end) {}
|
||||
|
||||
common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end, std::vector<common_peg_ast_id> nodes)
|
||||
: type(type), start(start), end(end), nodes(std::move(nodes)) {}
|
||||
|
||||
bool fail() const { return type == COMMON_PEG_PARSE_RESULT_FAIL; }
|
||||
bool need_more_input() const { return type == COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT; }
|
||||
bool success() const { return type == COMMON_PEG_PARSE_RESULT_SUCCESS; }
|
||||
};
|
||||
|
||||
struct common_peg_parse_context {
|
||||
std::string input;
|
||||
bool is_partial;
|
||||
common_peg_ast_arena ast;
|
||||
|
||||
int parse_depth;
|
||||
|
||||
common_peg_parse_context()
|
||||
: is_partial(false), parse_depth(0) {}
|
||||
|
||||
common_peg_parse_context(const std::string & input)
|
||||
: input(input), is_partial(false), parse_depth(0) {}
|
||||
|
||||
common_peg_parse_context(const std::string & input, bool is_partial)
|
||||
: input(input), is_partial(is_partial), parse_depth(0) {}
|
||||
};
|
||||
|
||||
class common_peg_arena;
|
||||
|
||||
// Parser variants
|
||||
struct common_peg_epsilon_parser {};
|
||||
|
||||
struct common_peg_start_parser {};
|
||||
|
||||
struct common_peg_end_parser {};
|
||||
|
||||
struct common_peg_literal_parser {
|
||||
std::string literal;
|
||||
};
|
||||
|
||||
struct common_peg_sequence_parser {
|
||||
std::vector<common_peg_parser_id> children;
|
||||
};
|
||||
|
||||
struct common_peg_choice_parser {
|
||||
std::vector<common_peg_parser_id> children;
|
||||
};
|
||||
|
||||
struct common_peg_repetition_parser {
|
||||
common_peg_parser_id child;
|
||||
int min_count;
|
||||
int max_count; // -1 for unbounded
|
||||
};
|
||||
|
||||
struct common_peg_and_parser {
|
||||
common_peg_parser_id child;
|
||||
};
|
||||
|
||||
struct common_peg_not_parser {
|
||||
common_peg_parser_id child;
|
||||
};
|
||||
|
||||
struct common_peg_any_parser {};
|
||||
|
||||
struct common_peg_space_parser {};
|
||||
|
||||
struct common_peg_chars_parser {
|
||||
struct char_range {
|
||||
uint32_t start;
|
||||
uint32_t end;
|
||||
bool contains(uint32_t codepoint) const { return codepoint >= start && codepoint <= end; }
|
||||
};
|
||||
|
||||
std::string pattern;
|
||||
std::vector<char_range> ranges;
|
||||
bool negated;
|
||||
int min_count;
|
||||
int max_count; // -1 for unbounded
|
||||
};
|
||||
|
||||
struct common_peg_json_string_parser {};
|
||||
|
||||
struct common_peg_until_parser {
|
||||
std::vector<std::string> delimiters;
|
||||
};
|
||||
|
||||
struct common_peg_schema_parser {
|
||||
common_peg_parser_id child;
|
||||
std::string name;
|
||||
std::shared_ptr<nlohmann::ordered_json> schema;
|
||||
|
||||
// Indicates if the GBNF should accept a raw string that matches the schema.
|
||||
bool raw;
|
||||
};
|
||||
|
||||
struct common_peg_rule_parser {
|
||||
std::string name;
|
||||
common_peg_parser_id child;
|
||||
bool trigger;
|
||||
};
|
||||
|
||||
struct common_peg_ref_parser {
|
||||
std::string name;
|
||||
};
|
||||
|
||||
struct common_peg_atomic_parser {
|
||||
common_peg_parser_id child;
|
||||
};
|
||||
|
||||
struct common_peg_tag_parser {
|
||||
common_peg_parser_id child;
|
||||
std::string tag;
|
||||
};
|
||||
|
||||
// Variant holding all parser types
|
||||
using common_peg_parser_variant = std::variant<
|
||||
common_peg_epsilon_parser,
|
||||
common_peg_start_parser,
|
||||
common_peg_end_parser,
|
||||
common_peg_literal_parser,
|
||||
common_peg_sequence_parser,
|
||||
common_peg_choice_parser,
|
||||
common_peg_repetition_parser,
|
||||
common_peg_and_parser,
|
||||
common_peg_not_parser,
|
||||
common_peg_any_parser,
|
||||
common_peg_space_parser,
|
||||
common_peg_chars_parser,
|
||||
common_peg_json_string_parser,
|
||||
common_peg_until_parser,
|
||||
common_peg_schema_parser,
|
||||
common_peg_rule_parser,
|
||||
common_peg_ref_parser,
|
||||
common_peg_atomic_parser,
|
||||
common_peg_tag_parser
|
||||
>;
|
||||
|
||||
class common_peg_arena {
|
||||
std::vector<common_peg_parser_variant> parsers_;
|
||||
std::unordered_map<std::string, common_peg_parser_id> rules_;
|
||||
common_peg_parser_id root_ = COMMON_PEG_INVALID_PARSER_ID;
|
||||
|
||||
public:
|
||||
const common_peg_parser_variant & get(common_peg_parser_id id) const { return parsers_.at(id); }
|
||||
common_peg_parser_variant & get(common_peg_parser_id id) { return parsers_.at(id); }
|
||||
|
||||
size_t size() const { return parsers_.size(); }
|
||||
bool empty() const { return parsers_.empty(); }
|
||||
|
||||
common_peg_parser_id get_rule(const std::string & name) const;
|
||||
bool has_rule(const std::string & name) const { return rules_.find(name) != rules_.end(); }
|
||||
|
||||
common_peg_parser_id root() const { return root_; }
|
||||
void set_root(common_peg_parser_id id) { root_ = id; }
|
||||
|
||||
common_peg_parse_result parse(common_peg_parse_context & ctx, size_t start = 0) const;
|
||||
common_peg_parse_result parse(common_peg_parser_id id, common_peg_parse_context & ctx, size_t start) const;
|
||||
|
||||
void resolve_refs();
|
||||
|
||||
void build_grammar(const common_grammar_builder & builder, bool lazy = false) const;
|
||||
|
||||
std::string dump(common_peg_parser_id id) const;
|
||||
|
||||
nlohmann::json to_json() const;
|
||||
static common_peg_arena from_json(const nlohmann::json & j);
|
||||
|
||||
std::string save() const;
|
||||
void load(const std::string & data);
|
||||
|
||||
friend class common_peg_parser_builder;
|
||||
|
||||
private:
|
||||
common_peg_parser_id add_parser(common_peg_parser_variant parser);
|
||||
void add_rule(const std::string & name, common_peg_parser_id id);
|
||||
|
||||
common_peg_parser_id resolve_ref(common_peg_parser_id id);
|
||||
};
|
||||
|
||||
class common_peg_parser_builder {
|
||||
common_peg_arena arena_;
|
||||
|
||||
common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); }
|
||||
common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); }
|
||||
|
||||
public:
|
||||
common_peg_parser_builder();
|
||||
|
||||
// Match nothing, always succeed.
|
||||
// S -> ε
|
||||
common_peg_parser eps() { return add(common_peg_epsilon_parser{}); }
|
||||
|
||||
// Matches the start of the input.
|
||||
// S -> ^
|
||||
common_peg_parser start() { return add(common_peg_start_parser{}); }
|
||||
|
||||
// Matches the end of the input.
|
||||
// S -> $
|
||||
common_peg_parser end() { return add(common_peg_end_parser{}); }
|
||||
|
||||
// Matches an exact literal string.
|
||||
// S -> "hello"
|
||||
common_peg_parser literal(const std::string & literal) { return add(common_peg_literal_parser{literal}); }
|
||||
|
||||
// Matches a sequence of parsers in order, all must succeed.
|
||||
// S -> A B C
|
||||
common_peg_parser sequence() { return add(common_peg_sequence_parser{}); }
|
||||
common_peg_parser sequence(const std::vector<common_peg_parser_id> & parsers);
|
||||
common_peg_parser sequence(const std::vector<common_peg_parser> & parsers);
|
||||
common_peg_parser sequence(std::initializer_list<common_peg_parser> parsers);
|
||||
|
||||
// Matches the first parser that succeeds from a list of alternatives.
|
||||
// S -> A | B | C
|
||||
common_peg_parser choice() { return add(common_peg_choice_parser{}); }
|
||||
common_peg_parser choice(const std::vector<common_peg_parser_id> & parsers);
|
||||
common_peg_parser choice(const std::vector<common_peg_parser> & parsers);
|
||||
common_peg_parser choice(std::initializer_list<common_peg_parser> parsers);
|
||||
|
||||
// Matches one or more repetitions of a parser.
|
||||
// S -> A+
|
||||
common_peg_parser one_or_more(const common_peg_parser & p) { return repeat(p, 1, -1); }
|
||||
|
||||
// Matches zero or more repetitions of a parser, always succeeds.
|
||||
// S -> A*
|
||||
common_peg_parser zero_or_more(const common_peg_parser & p) { return repeat(p, 0, -1); }
|
||||
|
||||
// Matches zero or one occurrence of a parser, always succeeds.
|
||||
// S -> A?
|
||||
common_peg_parser optional(const common_peg_parser & p) { return repeat(p, 0, 1); }
|
||||
|
||||
// Positive lookahead: succeeds if child parser succeeds, consumes no input.
|
||||
// S -> &A
|
||||
common_peg_parser peek(const common_peg_parser & p) { return add(common_peg_and_parser{p}); }
|
||||
|
||||
// Negative lookahead: succeeds if child parser fails, consumes no input.
|
||||
// S -> !A
|
||||
common_peg_parser negate(const common_peg_parser & p) { return add(common_peg_not_parser{p}); }
|
||||
|
||||
// Matches any single character.
|
||||
// S -> .
|
||||
common_peg_parser any() { return add(common_peg_any_parser{}); }
|
||||
|
||||
// Matches between min and max repetitions of characters from a character class.
|
||||
// S -> [a-z]{m,n}
|
||||
//
|
||||
// Use -1 for max to represent unbounded repetition (equivalent to {m,})
|
||||
common_peg_parser chars(const std::string & classes, int min = 1, int max = -1);
|
||||
|
||||
// Creates a lightweight reference to a named rule (resolved during build()).
|
||||
// Use this for forward references in recursive grammars.
|
||||
// expr_ref -> expr
|
||||
common_peg_parser ref(const std::string & name) { return add(common_peg_ref_parser{name}); }
|
||||
|
||||
// Matches zero or more whitespace characters (space, tab, newline).
|
||||
// S -> [ \t\n]*
|
||||
common_peg_parser space() { return add(common_peg_space_parser{}); }
|
||||
|
||||
// Matches all characters until a delimiter is found (delimiter not consumed).
|
||||
// S -> (!delim .)*
|
||||
common_peg_parser until(const std::string & delimiter) { return add(common_peg_until_parser{{delimiter}}); }
|
||||
|
||||
// Matches all characters until one of the delimiters in the list is found (delimiter not consumed).
|
||||
// S -> (!delim .)*
|
||||
common_peg_parser until_one_of(const std::vector<std::string> & delimiters) { return add(common_peg_until_parser{delimiters}); }
|
||||
|
||||
// Matches everything
|
||||
// S -> .*
|
||||
common_peg_parser rest() { return until_one_of({}); }
|
||||
|
||||
// Matches between min and max repetitions of a parser (inclusive).
|
||||
// S -> A{m,n}
|
||||
// Use -1 for max to represent unbounded repetition (equivalent to {m,})
|
||||
common_peg_parser repeat(const common_peg_parser & p, int min, int max) { return add(common_peg_repetition_parser{p, min,max}); }
|
||||
|
||||
// Matches exactly n repetitions of a parser.
|
||||
// S -> A{n}
|
||||
common_peg_parser repeat(const common_peg_parser & p, int n) { return repeat(p, n, n); }
|
||||
|
||||
// Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null.
|
||||
// value -> object | array | string | number | true | false | null
|
||||
common_peg_parser json();
|
||||
common_peg_parser json_object();
|
||||
common_peg_parser json_string();
|
||||
common_peg_parser json_array();
|
||||
common_peg_parser json_number();
|
||||
common_peg_parser json_bool();
|
||||
common_peg_parser json_null();
|
||||
|
||||
// Matches JSON string content without the surrounding quotes.
|
||||
// Useful for extracting content within a JSON string.
|
||||
common_peg_parser json_string_content();
|
||||
|
||||
// Matches a JSON object member with a key and associated parser as the
|
||||
// value.
|
||||
common_peg_parser json_member(const std::string & key, const common_peg_parser & p);
|
||||
|
||||
// Wraps a parser with JSON schema metadata for grammar generation.
|
||||
// Used internally to convert JSON schemas to GBNF grammar rules.
|
||||
common_peg_parser schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw = false);
|
||||
|
||||
// Creates a named rule, stores it in the grammar, and returns a ref.
|
||||
// If trigger=true, marks this rule as an entry point for lazy grammar generation.
|
||||
// auto json = p.rule("json", json_obj | json_arr | ...)
|
||||
common_peg_parser rule(const std::string & name, const common_peg_parser & p, bool trigger = false);
|
||||
|
||||
// Creates a named rule using a builder function, and returns a ref.
|
||||
// If trigger=true, marks this rule as an entry point for lazy grammar generation.
|
||||
// auto json = p.rule("json", [&]() { return json_object() | json_array() | ... })
|
||||
common_peg_parser rule(const std::string & name, const std::function<common_peg_parser()> & builder, bool trigger = false);
|
||||
|
||||
// Creates a trigger rule. When generating a lazy grammar from the parser,
|
||||
// only trigger rules and their descendants are emitted.
|
||||
common_peg_parser trigger_rule(const std::string & name, const common_peg_parser & p) { return rule(name, p, true); }
|
||||
common_peg_parser trigger_rule(const std::string & name, const std::function<common_peg_parser()> & builder) { return rule(name, builder, true); }
|
||||
|
||||
// Creates an atomic parser. Atomic parsers do not create an AST node if
|
||||
// the child results in a partial parse, i.e. NEED_MORE_INPUT. This is
|
||||
// intended for situations where partial output is undesirable.
|
||||
common_peg_parser atomic(const common_peg_parser & p) { return add(common_peg_atomic_parser{p}); }
|
||||
|
||||
// Tags create nodes in the generated AST for semantic purposes.
|
||||
// Unlike rules, you can tag multiple nodes with the same tag.
|
||||
common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); }
|
||||
|
||||
void set_root(const common_peg_parser & p);
|
||||
|
||||
common_peg_arena build();
|
||||
};
|
||||
|
||||
// Helper function for building parsers
|
||||
common_peg_arena build_peg_parser(const std::function<common_peg_parser(common_peg_parser_builder & builder)> & fn);
|
||||
@@ -0,0 +1,64 @@
|
||||
#include "unicode.h"
|
||||
|
||||
// implementation adopted from src/unicode.cpp
|
||||
|
||||
// Returns the byte length of the UTF-8 sequence introduced by `first_byte`.
// The length is derived from the top four bits only, so bytes that cannot
// start a sequence are not rejected: continuation bytes (0x80-0xBF) map to 1
// and 0xF8-0xFF map to 4. The result is always in the range 1-4.
size_t utf8_sequence_length(unsigned char first_byte) {
    // Indexed by the high nibble: 0x0-0xB -> 1, 0xC-0xD -> 2, 0xE -> 3, 0xF -> 4.
    static const size_t length_by_high_nibble[16] = {
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4,
    };
    return length_by_high_nibble[first_byte >> 4];
}
|
||||
|
||||
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) {
|
||||
if (offset >= input.size()) {
|
||||
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
|
||||
}
|
||||
|
||||
// ASCII fast path
|
||||
if (!(input[offset] & 0x80)) {
|
||||
return utf8_parse_result(utf8_parse_result::SUCCESS, input[offset], 1);
|
||||
}
|
||||
|
||||
// Invalid: continuation byte as first byte
|
||||
if (!(input[offset] & 0x40)) {
|
||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
||||
}
|
||||
|
||||
// 2-byte sequence
|
||||
if (!(input[offset] & 0x20)) {
|
||||
if (offset + 1 >= input.size()) {
|
||||
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
|
||||
}
|
||||
if ((input[offset + 1] & 0xc0) != 0x80) {
|
||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
||||
}
|
||||
auto result = ((input[offset] & 0x1f) << 6) | (input[offset + 1] & 0x3f);
|
||||
return utf8_parse_result(utf8_parse_result::SUCCESS, result, 2);
|
||||
}
|
||||
|
||||
// 3-byte sequence
|
||||
if (!(input[offset] & 0x10)) {
|
||||
if (offset + 2 >= input.size()) {
|
||||
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
|
||||
}
|
||||
if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80) {
|
||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
||||
}
|
||||
auto result = ((input[offset] & 0x0f) << 12) | ((input[offset + 1] & 0x3f) << 6) | (input[offset + 2] & 0x3f);
|
||||
return utf8_parse_result(utf8_parse_result::SUCCESS, result, 3);
|
||||
}
|
||||
|
||||
// 4-byte sequence
|
||||
if (!(input[offset] & 0x08)) {
|
||||
if (offset + 3 >= input.size()) {
|
||||
return utf8_parse_result(utf8_parse_result::INCOMPLETE);
|
||||
}
|
||||
if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80 || (input[offset + 3] & 0xc0) != 0x80) {
|
||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
||||
}
|
||||
auto result = ((input[offset] & 0x07) << 18) | ((input[offset + 1] & 0x3f) << 12) | ((input[offset + 2] & 0x3f) << 6) | (input[offset + 3] & 0x3f);
|
||||
return utf8_parse_result(utf8_parse_result::SUCCESS, result, 4);
|
||||
}
|
||||
|
||||
// Invalid first byte
|
||||
return utf8_parse_result(utf8_parse_result::INVALID);
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <string_view>
|
||||
|
||||
// UTF-8 parsing utilities for streaming-aware unicode support
|
||||
|
||||
struct utf8_parse_result {
|
||||
uint32_t codepoint; // Decoded codepoint (only valid if status == SUCCESS)
|
||||
size_t bytes_consumed; // How many bytes this codepoint uses (1-4)
|
||||
enum status { SUCCESS, INCOMPLETE, INVALID } status;
|
||||
|
||||
utf8_parse_result(enum status s, uint32_t cp = 0, size_t bytes = 0)
|
||||
: codepoint(cp), bytes_consumed(bytes), status(s) {}
|
||||
};
|
||||
|
||||
// Determine the expected length of a UTF-8 sequence from its first byte
|
||||
// Never returns 0: the length is taken from the high nibble only, so invalid first bytes (e.g. continuation bytes 0x80-0xBF) yield 1
|
||||
size_t utf8_sequence_length(unsigned char first_byte);
|
||||
|
||||
// Parse a single UTF-8 codepoint from input
|
||||
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset);
|
||||
@@ -0,0 +1,288 @@
|
||||
# Parsing Model Output
|
||||
|
||||
The `common` library contains a PEG parser implementation suitable for parsing
|
||||
model output.
|
||||
|
||||
Types with the prefix `common_peg_*` are intended for general use and may have
|
||||
applications beyond parsing model output, such as parsing user-provided regex
|
||||
patterns.
|
||||
|
||||
Types with the prefix `common_chat_peg_*` are specialized helpers for model
|
||||
output.
|
||||
|
||||
The parser features:
|
||||
|
||||
- Partial parsing of streaming input
|
||||
- Built-in JSON parsers
|
||||
- AST generation with semantics via "tagged" nodes
|
||||
|
||||
## Example
|
||||
|
||||
Below is a contrived example demonstrating how to use the PEG parser to parse
|
||||
output from a model that emits arguments as JSON.
|
||||
|
||||
```cpp
|
||||
auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
|
||||
// Build a choice of all available tools
|
||||
auto tool_choice = p.choice();
|
||||
for (const auto & tool : tools) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
const auto & schema = function.at("parameters");
|
||||
|
||||
auto tool_name = p.json_member("name", "\"" + p.literal(name) + "\"");
|
||||
auto tool_args = p.json_member("arguments", p.schema(p.json(), "tool-" + name + "-schema", schema));
|
||||
|
||||
tool_choice |= p.rule("tool-" + name, "{" << tool_name << "," << tool_args << "}");
|
||||
}
|
||||
|
||||
// Define the tool call structure: <tool_call>[{tool}]</tool_call>
|
||||
auto tool_call = p.trigger_rule("tool-call",
|
||||
p.sequence({
|
||||
p.literal("<tool_call>["),
|
||||
tool_choice,
|
||||
p.literal("]</tool_call>")
|
||||
})
|
||||
);
|
||||
|
||||
// Parser accepts content, optionally followed by a tool call
|
||||
return p.sequence({
|
||||
p.content(p.until("<tool_call>")),
|
||||
p.optional(tool_call),
|
||||
p.end()
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
For a more complete example, see `test_example_native()` in
|
||||
[tests/test-chat-peg-parser.cpp](tests/test-chat-peg-parser.cpp).
|
||||
|
||||
## Parsers/Combinators
|
||||
|
||||
### Basic Matchers
|
||||
|
||||
- **`eps()`** - Matches nothing and always succeeds (epsilon/empty match)
|
||||
- **`start()`** - Matches the start of input (anchor `^`)
|
||||
- **`end()`** - Matches the end of input (anchor `$`)
|
||||
- **`literal(string)`** - Matches an exact literal string
|
||||
- **`any()`** - Matches any single character (`.`)
|
||||
|
||||
### Combinators
|
||||
|
||||
- **`sequence(...)`** - Matches parsers in order; all must succeed
|
||||
- **`choice(...)`** - Matches the first parser that succeeds from alternatives (ordered choice)
|
||||
- **`one_or_more(p)`** - Matches one or more repetitions (`+`)
|
||||
- **`zero_or_more(p)`** - Matches zero or more repetitions (`*`)
|
||||
- **`optional(p)`** - Matches zero or one occurrence (`?`)
|
||||
- **`repeat(p, min, max)`** - Matches between min and max repetitions (use `-1` for unbounded)
|
||||
- **`repeat(p, n)`** - Matches exactly n repetitions
|
||||
|
||||
### Lookahead
|
||||
|
||||
- **`peek(p)`** - Positive lookahead: succeeds if parser succeeds without consuming input (`&`)
|
||||
- **`negate(p)`** - Negative lookahead: succeeds if parser fails without consuming input (`!`)
|
||||
|
||||
### Character Classes & Utilities
|
||||
|
||||
- **`chars(classes, min, max)`** - Matches repetitions of characters from a character class
|
||||
- **`space()`** - Matches zero or more whitespace characters (space, tab, newline)
|
||||
- **`until(delimiter)`** - Matches characters until delimiter is found (delimiter not consumed)
|
||||
- **`until_one_of(delimiters)`** - Matches characters until any delimiter in the list is found
|
||||
- **`rest()`** - Matches everything remaining (`.*`)
|
||||
|
||||
### JSON Parsers
|
||||
|
||||
- **`json()`** - Complete JSON parser (objects, arrays, strings, numbers, booleans, null)
|
||||
- **`json_object()`** - JSON object parser
|
||||
- **`json_array()`** - JSON array parser
|
||||
- **`json_string()`** - JSON string parser
|
||||
- **`json_number()`** - JSON number parser
|
||||
- **`json_bool()`** - JSON boolean parser
|
||||
- **`json_null()`** - JSON null parser
|
||||
- **`json_string_content()`** - JSON string content without surrounding quotes
|
||||
- **`json_member(key, p)`** - JSON object member with specific key and value parser
|
||||
|
||||
### Grammar Building
|
||||
|
||||
- **`ref(name)`** - Creates a lightweight reference to a named rule (for recursive grammars)
|
||||
- **`rule(name, p, trigger)`** - Creates a named rule and returns a reference
|
||||
- **`trigger_rule(name, p)`** - Creates a trigger rule (entry point for lazy grammar generation)
|
||||
- **`schema(p, name, schema, raw)`** - Wraps parser with JSON schema metadata for grammar generation
|
||||
|
||||
### AST Control
|
||||
|
||||
- **`atomic(p)`** - Prevents AST node creation for partial parses
|
||||
- **`tag(tag, p)`** - Creates AST nodes with semantic tags (multiple nodes can share tags)
|
||||
|
||||
## GBNF Grammar Generation
|
||||
|
||||
The PEG parser also acts as a convenient DSL for generating GBNF grammars, with
|
||||
some exceptions.
|
||||
|
||||
```cpp
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(params.tools, [&](const json & fn) {
|
||||
builder.resolve_refs(fn.at("parameters"));
|
||||
});
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
```
|
||||
|
||||
The notable exception is the `negate(p)` lookahead parser, which cannot be
|
||||
expressed as a context-free grammar (CFG) and therefore does not produce a rule. Its usage
|
||||
should be limited and preferably hidden behind a `schema()` parser. In many
|
||||
cases, `until(delimiter)` or `until_one_of(delimiters)` is a better choice.
|
||||
|
||||
Another limitation is that the PEG parser requires an unambiguous grammar. In
|
||||
contrast, the `llama-grammar` implementation can support ambiguous grammars,
|
||||
though they are difficult to parse.
|
||||
|
||||
### Lazy Grammars
|
||||
|
||||
During lazy grammar generation, only rules reachable from a `trigger_rule(name, p)`
|
||||
are emitted in the grammar. All trigger rules are added as alternations in the
|
||||
root rule. It is still necessary to define trigger patterns, as the parser has
|
||||
no interaction with the grammar sampling.
|
||||
|
||||
### JSON Schema
|
||||
|
||||
The `schema(p, name, schema, raw)` parser will use the `json-schema-to-grammar`
|
||||
implementation to generate the grammar instead of the underlying parser.
|
||||
|
||||
The `raw` option emits a grammar suitable for a raw string instead of a JSON
|
||||
string. In other words, it won't be wrapped in quotes or require escaping
|
||||
quotes. It should only be used when `type == "string"`.
|
||||
|
||||
The downside is that it can potentially lead to ambiguous grammars. For
|
||||
example, if a user provides the pattern `^.*$`, the following grammar may be
|
||||
generated:
|
||||
|
||||
```
|
||||
root ::= "<arg>" .* "</arg>"
|
||||
```
|
||||
|
||||
This creates an ambiguous grammar that cannot be parsed by the PEG parser. To
|
||||
help mitigate this, if `.*` is found in the pattern, the grammar from the
|
||||
underlying parser will be emitted instead.
|
||||
|
||||
## Common AST Shapes for Chat Parsing
|
||||
|
||||
Most model output can be placed in one of the following categories:
|
||||
|
||||
- Content only
|
||||
- Tool calling with arguments emitted as a single JSON object
|
||||
- Tool calling with arguments emitted as separate entities, either XML
|
||||
(Qwen3-Coder, MiniMax M2) or pseudo-function calls (LFM2)
|
||||
|
||||
To provide broad coverage,
|
||||
[`common/chat-peg-parser.h`](common/chat-peg-parser.h) contains builders and
|
||||
mappers that help create parsers and visitors/extractors for these types. They
|
||||
require parsers to tag nodes to conform to an AST "shape". This normalization
|
||||
makes it easy to extract information and generalize parsing.
|
||||
|
||||
### Simple
|
||||
|
||||
The `common_chat_peg_builder` builds a `simple` parser that supports
|
||||
content-only models with optional reasoning.
|
||||
|
||||
- **`reasoning(p)`** - Tag node for extracting `reasoning_content`
|
||||
- **`content(p)`** - Tag node for extracting `content`
|
||||
|
||||
```cpp
|
||||
build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
return p.sequence({
|
||||
p.optional("<think>" + p.reasoning(p.until("</think>")) + "</think>"),
|
||||
p.content(p.until("<tool_call>")),
|
||||
p.end()
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
Use `common_chat_peg_mapper` to extract the content. Note that this is already
|
||||
done for you in `common_chat_peg_parser` when
|
||||
`chat_format == COMMON_CHAT_FORMAT_PEG_SIMPLE`.
|
||||
|
||||
```cpp
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
common_chat_msg msg;
|
||||
auto mapper = common_chat_peg_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
```
|
||||
|
||||
### Native
|
||||
|
||||
The `common_chat_peg_native_builder` builds a `native` parser suitable for
|
||||
models that emit tool arguments as a direct JSON object.
|
||||
|
||||
- **`reasoning(p)`** - Tag node for `reasoning_content`
|
||||
- **`content(p)`** - Tag node for `content`
|
||||
- **`tool(p)`** - Tag entirety of a single tool call
|
||||
- **`tool_open(p)`** - Tag start of a tool call
|
||||
- **`tool_close(p)`** - Tag end of a tool call
|
||||
- **`tool_id(p)`** - Tag the tool call ID (optional)
|
||||
- **`tool_name(p)`** - Tag the tool name
|
||||
- **`tool_args(p)`** - Tag the tool arguments
|
||||
|
||||
```cpp
|
||||
build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
|
||||
auto get_weather_tool = p.tool(p.sequence({
|
||||
p.tool_open(p.literal("{")),
|
||||
p.json_member("name", "\"" + p.tool_name(p.literal("get_weather")) + "\""),
|
||||
p.literal(","),
|
||||
p.json_member("arguments", p.tool_args(p.json())),
|
||||
p.tool_close(p.literal("}"))
|
||||
}));
|
||||
|
||||
return p.sequence({
|
||||
p.content(p.until("<tool_call>")),
|
||||
p.literal("<tool_call>"),
|
||||
get_weather_tool,
|
||||
p.literal("</tool_call>"),
|
||||
p.end()
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
### Constructed
|
||||
|
||||
The `common_chat_peg_constructed_builder` builds a `constructed` parser
|
||||
suitable for models that emit tool arguments as separate entities, such as XML
|
||||
tags.
|
||||
|
||||
- **`reasoning(p)`** - Tag node for `reasoning_content`
|
||||
- **`content(p)`** - Tag node for `content`
|
||||
- **`tool(p)`** - Tag entirety of a single tool call
|
||||
- **`tool_open(p)`** - Tag start of a tool call
|
||||
- **`tool_close(p)`** - Tag end of a tool call
|
||||
- **`tool_name(p)`** - Tag the tool name
|
||||
- **`tool_arg(p)`** - Tag a complete tool argument (name + value)
|
||||
- **`tool_arg_open(p)`** - Tag start of a tool argument
|
||||
- **`tool_arg_close(p)`** - Tag end of a tool argument
|
||||
- **`tool_arg_name(p)`** - Tag the argument name
|
||||
- **`tool_arg_string_value(p)`** - Tag string value for the argument
|
||||
- **`tool_arg_json_value(p)`** - Tag JSON value for the argument
|
||||
|
||||
```cpp
|
||||
build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) {
|
||||
auto location_arg = p.tool_arg(
|
||||
p.tool_arg_open("<parameter name=\"" + p.tool_arg_name(p.literal("location")) + "\">"),
|
||||
p.tool_arg_string_value(p.until("</parameter>")),
|
||||
p.tool_arg_close(p.literal("</parameter>"))
|
||||
);
|
||||
|
||||
auto get_weather_tool = p.tool(p.sequence({
|
||||
p.tool_open("<function name=\"" + p.tool_name(p.literal("get_weather")) + "\">"),
|
||||
location_arg,
|
||||
p.tool_close(p.literal("</function>"))
|
||||
}));
|
||||
|
||||
return p.sequence({
|
||||
p.content(p.until("<tool_call>")),
|
||||
p.literal("<tool_call>"),
|
||||
get_weather_tool,
|
||||
p.literal("</tool_call>"),
|
||||
p.end()
|
||||
});
|
||||
});
|
||||
```
|
||||
+1
-1
@@ -226,7 +226,7 @@ option(GGML_WEBGPU "ggml: use WebGPU"
|
||||
option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
|
||||
option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF)
|
||||
option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF)
|
||||
|
||||
option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON)
|
||||
option(GGML_ZDNN "ggml: use zDNN" OFF)
|
||||
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
|
||||
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
|
||||
|
||||
@@ -2698,6 +2698,11 @@ struct ggml_cplan ggml_graph_plan(
|
||||
n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS;
|
||||
}
|
||||
|
||||
#if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__)
|
||||
// Emscripten without pthreads support can only use a single thread
|
||||
n_threads = 1;
|
||||
#endif
|
||||
|
||||
size_t work_size = 0;
|
||||
|
||||
struct ggml_cplan cplan;
|
||||
|
||||
@@ -50,7 +50,7 @@ void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, gg
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {
|
||||
if (ppls->data.find(name) == ppls->data.end()) {
|
||||
if (ppls->data.find(name) == ppls->data.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
@@ -146,6 +146,8 @@ struct ggml_metal_library {
|
||||
id<MTLDevice> device;
|
||||
|
||||
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
|
||||
|
||||
NSLock * lock;
|
||||
};
|
||||
|
||||
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
||||
@@ -296,9 +298,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
||||
|
||||
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
|
||||
|
||||
res->obj = library;
|
||||
res->device = device;
|
||||
res->obj = library;
|
||||
res->device = device;
|
||||
res->pipelines = ggml_metal_pipelines_init();
|
||||
res->lock = [NSLock new];
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -365,6 +368,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
|
||||
res->obj = library;
|
||||
res->device = device;
|
||||
res->pipelines = ggml_metal_pipelines_init();
|
||||
res->lock = [NSLock new];
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -380,20 +384,27 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
|
||||
|
||||
ggml_metal_pipelines_free(lib->pipelines);
|
||||
|
||||
[lib->lock release];
|
||||
|
||||
free(lib);
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
|
||||
return ggml_metal_pipelines_get(lib->pipelines, name);
|
||||
[lib->lock lock];
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
|
||||
|
||||
[lib->lock unlock];
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
|
||||
// note: the pipelines are cached in the library per device, so they are shared across all metal contexts
|
||||
ggml_critical_section_start();
|
||||
[lib->lock lock];
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
|
||||
if (res) {
|
||||
ggml_critical_section_end();
|
||||
[lib->lock unlock];
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -414,7 +425,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
|
||||
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
|
||||
}
|
||||
if (!mtl_function) {
|
||||
ggml_critical_section_end();
|
||||
[lib->lock unlock];
|
||||
|
||||
GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
|
||||
if (error) {
|
||||
@@ -433,7 +444,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
|
||||
(int) res->obj.threadExecutionWidth);
|
||||
|
||||
if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
|
||||
ggml_critical_section_end();
|
||||
[lib->lock unlock];
|
||||
|
||||
GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
|
||||
|
||||
@@ -443,7 +454,7 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
|
||||
ggml_metal_pipelines_add(lib->pipelines, name, res);
|
||||
}
|
||||
|
||||
ggml_critical_section_end();
|
||||
[lib->lock unlock];
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -39,8 +39,23 @@ add_dependencies(ggml-webgpu generate_shaders)
|
||||
if(EMSCRIPTEN)
|
||||
set(EMDAWNWEBGPU_DIR "" CACHE PATH "Path to emdawnwebgpu_pkg")
|
||||
|
||||
target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
|
||||
target_link_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
|
||||
if(NOT EMDAWNWEBGPU_DIR)
|
||||
# default built-in port
|
||||
target_compile_options(ggml-webgpu PRIVATE "--use-port=emdawnwebgpu")
|
||||
target_link_options(ggml-webgpu INTERFACE "--use-port=emdawnwebgpu")
|
||||
else()
|
||||
# custom port
|
||||
target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
|
||||
target_link_options(ggml-webgpu INTERFACE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
|
||||
endif()
|
||||
|
||||
# Select the async mechanism and the matching C++ exception model.
# The exception flag has to be given at both compile time and link time;
# the link options are INTERFACE so they reach the final executable that
# links against ggml-webgpu.
if (GGML_WEBGPU_JSPI)
    # JSPI build: native WebAssembly exception handling.
    target_compile_options(ggml-webgpu PRIVATE "-fwasm-exceptions")
    target_link_options(ggml-webgpu INTERFACE "-sJSPI" "-fwasm-exceptions")
else()
    # Asyncify build: Emscripten (JavaScript-based) exception handling.
    # The link flag must be spelled "-fexceptions", same as the compile flag.
    target_compile_options(ggml-webgpu PRIVATE "-fexceptions")
    target_link_options(ggml-webgpu INTERFACE "-sASYNCIFY" "-fexceptions")
endif()
|
||||
else()
|
||||
find_package(Dawn REQUIRED)
|
||||
set(DawnWebGPU_TARGET dawn::webgpu_dawn)
|
||||
@@ -48,6 +63,9 @@ endif()
|
||||
|
||||
if (GGML_WEBGPU_DEBUG)
|
||||
target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1)
|
||||
if(EMSCRIPTEN)
|
||||
target_link_options(ggml-webgpu INTERFACE "-sASSERTIONS=2")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (GGML_WEBGPU_CPU_PROFILE)
|
||||
|
||||
@@ -9,6 +9,10 @@
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-wgsl-shaders.hpp"
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
# include <emscripten/emscripten.h>
|
||||
#endif
|
||||
|
||||
#include <webgpu/webgpu_cpp.h>
|
||||
|
||||
#include <atomic>
|
||||
@@ -261,9 +265,12 @@ struct webgpu_context_struct {
|
||||
wgpu::Queue queue;
|
||||
wgpu::Limits limits;
|
||||
|
||||
uint32_t subgroup_size;
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
bool supports_subgroup_matrix = false;
|
||||
uint32_t subgroup_size;
|
||||
wgpu::SubgroupMatrixConfig subgroup_matrix_config;
|
||||
#endif
|
||||
|
||||
// Separate this out from limits since on some Metal systems, the limit returned by
|
||||
// querying the limits is higher than the actual allowed maximum.
|
||||
@@ -449,8 +456,8 @@ static void ggml_backend_webgpu_wait(webgpu_context & ct
|
||||
// If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
|
||||
// inflight_max may be 0, meaning that we must wait on all futures.
|
||||
uint64_t timeout_ms = block ? UINT64_MAX : 0;
|
||||
uint inflight_threads = ctx->inflight_threads;
|
||||
uint inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
|
||||
uint32_t inflight_threads = ctx->inflight_threads;
|
||||
uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
|
||||
while (futures.size() >= inflight_max && futures.size() > 0) {
|
||||
ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
|
||||
futures.erase(futures.begin());
|
||||
@@ -986,6 +993,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
|
||||
uint32_t wg_m;
|
||||
uint32_t wg_n;
|
||||
#ifndef __EMSCRIPTEN__
|
||||
if (ctx->supports_subgroup_matrix) {
|
||||
// The total number of subgroups/workgroups needed per matrix.
|
||||
uint32_t wg_m_sg_tile =
|
||||
@@ -995,11 +1003,15 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N;
|
||||
wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile;
|
||||
} else {
|
||||
#endif
|
||||
uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
|
||||
uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
|
||||
wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s;
|
||||
wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s;
|
||||
#ifndef __EMSCRIPTEN__
|
||||
}
|
||||
#endif
|
||||
|
||||
wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
|
||||
}
|
||||
}
|
||||
@@ -1419,9 +1431,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
||||
commands.push_back(*cmd);
|
||||
}
|
||||
// compute the batch size based on the number of inflight threads
|
||||
uint inflight_threads = ctx->inflight_threads;
|
||||
uint batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
|
||||
WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
|
||||
uint32_t inflight_threads = ctx->inflight_threads;
|
||||
uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
|
||||
WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
|
||||
if (commands.size() >= batch_size) {
|
||||
futures.push_back(ggml_backend_webgpu_submit(ctx, commands));
|
||||
// Process events and check for completed submissions
|
||||
@@ -1758,6 +1770,17 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
|
||||
wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
|
||||
|
||||
std::string proc_mul_mat_f32_f32;
|
||||
std::string proc_mul_mat_f32_f32_vec;
|
||||
std::string proc_mul_mat_f16_f32;
|
||||
std::string proc_mul_mat_f16_f32_vec;
|
||||
std::string proc_mul_mat_f16_f16;
|
||||
std::string proc_mul_mat_f16_f16_vec;
|
||||
std::string proc_mul_mat_q4_0_f32;
|
||||
std::string proc_mul_mat_q4_0_f32_vec;
|
||||
|
||||
std::vector<wgpu::ConstantEntry> mul_mat_constants;
|
||||
#ifndef __EMSCRIPTEN__
|
||||
if (webgpu_ctx->supports_subgroup_matrix) {
|
||||
std::map<std::string, std::string> sg_matrix_repls;
|
||||
sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size);
|
||||
@@ -1770,100 +1793,57 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N);
|
||||
sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K);
|
||||
|
||||
std::string proc_mul_mat_subgroup_matrix_f32_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_f32_f32_vec =
|
||||
proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
|
||||
proc_mul_mat_f32_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_f16_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_f16_f32_vec =
|
||||
proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
|
||||
proc_mul_mat_f16_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_f16_f16 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_f16_f16_vec =
|
||||
proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
|
||||
proc_mul_mat_f16_f16_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_q4_0_f32 =
|
||||
proc_mul_mat_q4_0_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
|
||||
std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec =
|
||||
proc_mul_mat_q4_0_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
|
||||
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str(),
|
||||
"mul_mat_subgroup_matrix_f32_f32_vec");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32.c_str(), "mul_mat_subgroup_matrix_f16_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str(),
|
||||
"mul_mat_subgroup_matrix_f16_f32_vec");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16.c_str(), "mul_mat_subgroup_matrix_f16_f16");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(),
|
||||
"mul_mat_subgroup_matrix_f16_f16_vec");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32.c_str(), "mul_mat_subgroup_matrix_q4_0_f32");
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str(),
|
||||
"mul_mat_subgroup_matrix_q4_0_f32_vec");
|
||||
} else {
|
||||
std::vector<wgpu::ConstantEntry> mul_mat_reg_tile_constants(3);
|
||||
mul_mat_reg_tile_constants[0].key = "TILE_K";
|
||||
mul_mat_reg_tile_constants[0].value = WEBGPU_MUL_MAT_TILE_K;
|
||||
mul_mat_reg_tile_constants[1].key = "WORKGROUP_SIZE_M";
|
||||
mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M;
|
||||
mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N";
|
||||
mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N;
|
||||
#endif
|
||||
mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K });
|
||||
mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M });
|
||||
mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N });
|
||||
|
||||
std::map<std::string, std::string> reg_repls;
|
||||
reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
|
||||
reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
|
||||
|
||||
// Process each reg-tile shader with tile replacements.
|
||||
// Keep the processed strings in-scope so .c_str() remains valid.
|
||||
std::string proc_mul_mat_reg_tile_f32_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_f32_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_f16_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_f16_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_f16_f16 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_f16_f16_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_q4_0_f32 =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
|
||||
std::string proc_mul_mat_reg_tile_q4_0_f32_vec =
|
||||
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
|
||||
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(),
|
||||
"mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32_vec.c_str(),
|
||||
"mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32.c_str(),
|
||||
"mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32_vec.c_str(),
|
||||
"mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16.c_str(),
|
||||
"mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(),
|
||||
"mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32.c_str(),
|
||||
"mul_mat_reg_tile_q4_0_f32", mul_mat_reg_tile_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(),
|
||||
"mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants);
|
||||
proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
|
||||
proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
|
||||
proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
|
||||
proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
|
||||
proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
|
||||
proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
|
||||
proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
|
||||
proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
|
||||
#ifndef __EMSCRIPTEN__
|
||||
}
|
||||
#endif
|
||||
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
|
||||
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
|
||||
webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);
|
||||
|
||||
std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
|
||||
mul_mat_vec_constants[0].key = "WORKGROUP_SIZE";
|
||||
@@ -2384,13 +2364,17 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
|
||||
webgpu_context ctx = reg_ctx->webgpu_ctx;
|
||||
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
// TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
|
||||
const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
|
||||
wgpu::DawnTogglesDescriptor adapterTogglesDesc;
|
||||
adapterTogglesDesc.enabledToggles = adapterEnabledToggles;
|
||||
adapterTogglesDesc.enabledToggleCount = 2;
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
options.nextInChain = &adapterTogglesDesc;
|
||||
#endif
|
||||
|
||||
ctx->instance.WaitAny(ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
||||
@@ -2406,11 +2390,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
ctx->adapter.GetLimits(&ctx->limits);
|
||||
ctx->max_wg_size_x = 288; // default value
|
||||
|
||||
wgpu::AdapterInfo info{};
|
||||
wgpu::AdapterInfo info{};
|
||||
#ifndef __EMSCRIPTEN__
|
||||
wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
|
||||
if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
||||
info.nextInChain = &subgroup_matrix_configs;
|
||||
}
|
||||
#endif
|
||||
ctx->adapter.GetInfo(&info);
|
||||
|
||||
wgpu::SupportedFeatures features;
|
||||
@@ -2418,6 +2404,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
// we require f16 support
|
||||
GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
// Only support square f16 matrices of size 8 or 16 for now
|
||||
bool valid_subgroup_matrix_config = false;
|
||||
if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
||||
@@ -2433,36 +2420,27 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
}
|
||||
}
|
||||
|
||||
ctx->supports_subgroup_matrix = valid_subgroup_matrix_config;
|
||||
#endif
|
||||
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
|
||||
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
|
||||
ctx->subgroup_size = info.subgroupMaxSize;
|
||||
ctx->supports_subgroup_matrix = valid_subgroup_matrix_config;
|
||||
ctx->subgroup_size = info.subgroupMaxSize;
|
||||
|
||||
// Initialize device
|
||||
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
|
||||
wgpu::FeatureName::ImplicitDeviceSynchronization };
|
||||
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
|
||||
if (ctx->supports_subgroup_matrix) {
|
||||
required_features.push_back(wgpu::FeatureName::Subgroups);
|
||||
required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
required_features.push_back(wgpu::FeatureName::TimestampQuery);
|
||||
#endif
|
||||
|
||||
// Enable Dawn-specific toggles to increase native performance
|
||||
// TODO: Don't enable for WASM builds, they won't have an effect anyways
|
||||
// TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
|
||||
// only for native performance?
|
||||
const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
|
||||
"disable_polyfills_on_integer_div_and_mod" };
|
||||
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
|
||||
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
|
||||
deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
|
||||
deviceTogglesDesc.enabledToggleCount = 4;
|
||||
deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
|
||||
deviceTogglesDesc.disabledToggleCount = 1;
|
||||
|
||||
wgpu::DeviceDescriptor dev_desc;
|
||||
dev_desc.requiredLimits = &ctx->limits;
|
||||
dev_desc.requiredFeatures = required_features.data();
|
||||
@@ -2480,7 +2458,23 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
||||
GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
||||
std::string(message).c_str());
|
||||
});
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
// Enable Dawn-specific toggles to increase native performance
|
||||
// TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
|
||||
// only for native performance?
|
||||
const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
|
||||
"disable_polyfills_on_integer_div_and_mod" };
|
||||
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
|
||||
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
|
||||
deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
|
||||
deviceTogglesDesc.enabledToggleCount = 4;
|
||||
deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
|
||||
deviceTogglesDesc.disabledToggleCount = 1;
|
||||
|
||||
dev_desc.nextInChain = &deviceTogglesDesc;
|
||||
#endif
|
||||
|
||||
ctx->instance.WaitAny(ctx->adapter.RequestDevice(
|
||||
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
||||
@@ -2578,18 +2572,27 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
ctx.name = GGML_WEBGPU_NAME;
|
||||
ctx.device_count = 1;
|
||||
|
||||
const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
|
||||
|
||||
wgpu::DawnTogglesDescriptor instanceTogglesDesc;
|
||||
instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
|
||||
instanceTogglesDesc.enabledToggleCount = 1;
|
||||
wgpu::InstanceDescriptor instance_descriptor{};
|
||||
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
|
||||
instance_descriptor.requiredFeatures = instance_features.data();
|
||||
instance_descriptor.requiredFeatureCount = instance_features.size();
|
||||
instance_descriptor.nextInChain = &instanceTogglesDesc;
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
|
||||
wgpu::DawnTogglesDescriptor instanceTogglesDesc;
|
||||
instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
|
||||
instanceTogglesDesc.enabledToggleCount = 1;
|
||||
instance_descriptor.nextInChain = &instanceTogglesDesc;
|
||||
#endif
|
||||
|
||||
webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
if (webgpu_ctx->instance == nullptr) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
GGML_ASSERT(webgpu_ctx->instance != nullptr);
|
||||
|
||||
static ggml_backend_reg reg = {
|
||||
|
||||
+1
-1
@@ -1169,7 +1169,7 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo
|
||||
struct gguf_writer_base {
|
||||
size_t written_bytes {0u};
|
||||
|
||||
~gguf_writer_base(void) {}
|
||||
~gguf_writer_base(void) = default;
|
||||
|
||||
// we bet on devirtualization
|
||||
virtual void write(int8_t val) = 0;
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
const http = require('http');
|
||||
const fs = require('fs').promises;
|
||||
const path = require('path');
|
||||
|
||||
// This file is used for testing wasm build from emscripten
|
||||
// Example build command:
|
||||
// emcmake cmake -B build-wasm -DGGML_WEBGPU=ON -DLLAMA_CURL=OFF
|
||||
// cmake --build build-wasm --target test-backend-ops -j
|
||||
|
||||
const PORT = 8080;
|
||||
const STATIC_DIR = path.join(__dirname, '../build-wasm/bin');
|
||||
console.log(`Serving static files from: ${STATIC_DIR}`);
|
||||
|
||||
// Maps a lower-cased file extension to the Content-Type sent for it.
// Extensions not listed here are served as 'application/octet-stream'.
const mimeTypes = {
    '.html': 'text/html',
    '.js': 'text/javascript',
    '.css': 'text/css',
    '.png': 'image/png',
    '.jpg': 'image/jpeg',
    '.gif': 'image/gif',
    '.svg': 'image/svg+xml',
    '.json': 'application/json',
    '.woff': 'font/woff',
    '.woff2': 'font/woff2',
    // Required for WebAssembly.instantiateStreaming(), which rejects any
    // response that is not served as application/wasm.
    '.wasm': 'application/wasm',
};
|
||||
|
||||
/**
 * Build an HTML page that lists the entries of a directory.
 *
 * @param {string} dirPath - Filesystem path of the directory to list.
 * @param {string} reqUrl  - URL the client requested; shown in the heading and
 *                           used to decide whether a parent link is emitted.
 * @returns {Promise<string>} The complete HTML document.
 */
async function generateDirListing(dirPath, reqUrl) {
    // Both the request URL and the file names are outside our control, so they
    // must be escaped before being interpolated into markup.
    const escapeHtml = (text) => String(text)
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;')
        .replace(/"/g, '&quot;')
        .replace(/'/g, '&#39;');

    const files = await fs.readdir(dirPath);
    let html = `
        <!DOCTYPE html>
        <html>
        <head>
            <title>Directory Listing</title>
            <style>
                body { font-family: Arial, sans-serif; padding: 20px; }
                ul { list-style: none; padding: 0; }
                li { margin: 5px 0; }
                a { text-decoration: none; color: #0066cc; }
                a:hover { text-decoration: underline; }
            </style>
        </head>
        <body>
            <h1>Directory: ${escapeHtml(reqUrl)}</h1>
            <ul>
    `;

    if (reqUrl !== '/') {
        html += `<li><a href="../">../ (Parent Directory)</a></li>`;
    }

    for (const file of files) {
        const filePath = path.join(dirPath, file);
        const stats = await fs.stat(filePath);
        // Directories get a trailing slash so relative links resolve inside them.
        const suffix = stats.isDirectory() ? '/' : '';
        // encodeURIComponent already percent-encodes quotes and angle brackets,
        // so the href value is safe inside the double-quoted attribute.
        const link = encodeURIComponent(file) + suffix;
        html += `<li><a href="${link}">${escapeHtml(file)}${suffix}</a></li>`;
    }

    html += `
            </ul>
        </body>
        </html>
    `;
    return html;
}
|
||||
|
||||
// Static file server for the wasm build output.
// Every response carries COOP/COEP headers (cross-origin isolation) and
// cache-disabling headers so a rebuilt binary is always re-fetched.
const server = http.createServer(async (req, res) => {
    try {
        // Set COOP and COEP headers
        res.setHeader('Cross-Origin-Opener-Policy', 'same-origin');
        res.setHeader('Cross-Origin-Embedder-Policy', 'require-corp');
        res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate, proxy-revalidate');
        res.setHeader('Pragma', 'no-cache');
        res.setHeader('Expires', '0');

        // Only the path component selects a file; drop any query string.
        let requestPath;
        try {
            requestPath = decodeURIComponent(req.url.split('?')[0]);
        } catch {
            // Malformed percent-escape in the URL.
            res.writeHead(400, { 'Content-Type': 'text/plain' });
            res.end('400 Bad Request');
            return;
        }

        // Resolve the request against the static root and refuse anything that
        // escapes it (e.g. "/../../etc/passwd" or an encoded "..%2f").
        const rootDir = path.resolve(STATIC_DIR);
        const filePath = path.resolve(path.join(rootDir, requestPath));
        if (filePath !== rootDir && !filePath.startsWith(rootDir + path.sep)) {
            res.writeHead(403, { 'Content-Type': 'text/plain' });
            res.end('403 Forbidden');
            return;
        }

        const stats = await fs.stat(filePath);

        if (stats.isDirectory()) {
            const indexPath = path.join(filePath, 'index.html');
            try {
                const indexData = await fs.readFile(indexPath);
                res.writeHead(200, { 'Content-Type': 'text/html' });
                res.end(indexData);
            } catch {
                // No index.html, generate directory listing
                const dirListing = await generateDirListing(filePath, req.url);
                res.writeHead(200, { 'Content-Type': 'text/html' });
                res.end(dirListing);
            }
        } else {
            const ext = path.extname(filePath).toLowerCase();
            const contentType = mimeTypes[ext] || 'application/octet-stream';
            const data = await fs.readFile(filePath);
            res.writeHead(200, { 'Content-Type': contentType });
            res.end(data);
        }
    } catch (err) {
        if (err.code === 'ENOENT') {
            res.writeHead(404, { 'Content-Type': 'text/plain' });
            res.end('404 Not Found');
        } else {
            res.writeHead(500, { 'Content-Type': 'text/plain' });
            res.end('500 Internal Server Error');
        }
    }
});
|
||||
|
||||
server.listen(PORT, () => {
|
||||
console.log(`Server running at http://localhost:${PORT}/`);
|
||||
});
|
||||
+1
-1
@@ -37,7 +37,7 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
|
||||
template <typename T>
|
||||
struct no_init {
|
||||
T value;
|
||||
no_init() { /* do nothing */ }
|
||||
no_init() = default;
|
||||
};
|
||||
|
||||
struct time_meas {
|
||||
|
||||
+3
-3
@@ -423,8 +423,8 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s
|
||||
}
|
||||
|
||||
struct llama_model::impl {
|
||||
impl() {}
|
||||
~impl() {}
|
||||
impl() = default;
|
||||
~impl() = default;
|
||||
|
||||
uint64_t n_elements = 0;
|
||||
|
||||
@@ -461,7 +461,7 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi
|
||||
pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern;
|
||||
}
|
||||
|
||||
llama_model::~llama_model() {}
|
||||
llama_model::~llama_model() = default;
|
||||
|
||||
void llama_model::load_stats(llama_model_loader & ml) {
|
||||
pimpl->n_elements = ml.n_elements;
|
||||
|
||||
+1
-2
@@ -3253,8 +3253,7 @@ void llama_vocab::impl::print_info() const {
|
||||
llama_vocab::llama_vocab() : pimpl(new impl(*this)) {
|
||||
}
|
||||
|
||||
llama_vocab::~llama_vocab() {
|
||||
}
|
||||
llama_vocab::~llama_vocab() = default;
|
||||
|
||||
void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
pimpl->load(ml, kv);
|
||||
|
||||
@@ -3,3 +3,4 @@
|
||||
*.o
|
||||
ggml-common.h
|
||||
**/*.swp
|
||||
!peg-parser
|
||||
|
||||
+18
-2
@@ -1,13 +1,15 @@
|
||||
llama_add_compile_flags()
|
||||
|
||||
function(llama_build source)
|
||||
set(TEST_SOURCES ${source} ${ARGN})
|
||||
|
||||
if (DEFINED LLAMA_TEST_NAME)
|
||||
set(TEST_TARGET ${LLAMA_TEST_NAME})
|
||||
else()
|
||||
get_filename_component(TEST_TARGET ${source} NAME_WE)
|
||||
endif()
|
||||
|
||||
add_executable(${TEST_TARGET} ${source})
|
||||
add_executable(${TEST_TARGET} ${TEST_SOURCES})
|
||||
target_link_libraries(${TEST_TARGET} PRIVATE common)
|
||||
install(TARGETS ${TEST_TARGET} RUNTIME)
|
||||
endfunction()
|
||||
@@ -83,6 +85,8 @@ function(llama_build_and_test source)
|
||||
set(multiValueArgs ARGS)
|
||||
cmake_parse_arguments(LLAMA_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
set(TEST_SOURCES ${source} ${LLAMA_TEST_UNPARSED_ARGUMENTS} get-model.cpp)
|
||||
|
||||
if (NOT DEFINED LLAMA_TEST_LABEL)
|
||||
set(LLAMA_TEST_LABEL "main")
|
||||
endif()
|
||||
@@ -95,7 +99,7 @@ function(llama_build_and_test source)
|
||||
get_filename_component(TEST_TARGET ${source} NAME_WE)
|
||||
endif()
|
||||
|
||||
add_executable(${TEST_TARGET} ${source} get-model.cpp)
|
||||
add_executable(${TEST_TARGET} ${TEST_SOURCES})
|
||||
install(TARGETS ${TEST_TARGET} RUNTIME)
|
||||
target_link_libraries(${TEST_TARGET} PRIVATE common)
|
||||
|
||||
@@ -180,9 +184,21 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
|
||||
endif()
|
||||
|
||||
llama_build_and_test(test-chat-parser.cpp)
|
||||
llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp)
|
||||
llama_build_and_test(test-chat-template.cpp)
|
||||
llama_build_and_test(test-json-partial.cpp)
|
||||
llama_build_and_test(test-log.cpp)
|
||||
llama_build_and_test(
|
||||
test-peg-parser.cpp
|
||||
peg-parser/simple-tokenize.cpp
|
||||
peg-parser/test-basic.cpp
|
||||
peg-parser/test-gbnf-generation.cpp
|
||||
peg-parser/test-json-parser.cpp
|
||||
peg-parser/test-json-serialization.cpp
|
||||
peg-parser/test-unicode.cpp
|
||||
peg-parser/testing.h
|
||||
peg-parser/tests.h
|
||||
)
|
||||
llama_build_and_test(test-regex-partial.cpp)
|
||||
|
||||
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
#include "simple-tokenize.h"
|
||||
|
||||
// Splits `input` into chunks for feeding a parser incrementally.
//
// A new chunk starts at every delimiter character (whitespace and common
// JSON / XML punctuation). The delimiter is not dropped: it becomes the first
// character of the chunk that follows, so concatenating all returned chunks
// reproduces `input` exactly. An empty input yields an empty vector.
std::vector<std::string> simple_tokenize(const std::string & input) {
    // Characters that open a new chunk.
    static const std::string delimiters = " \n\t{},[\"].<>=/";

    std::vector<std::string> tokens;
    std::string pending;

    for (const char ch : input) {
        const bool opens_new_chunk = delimiters.find(ch) != std::string::npos;

        // Flush whatever was collected so far before the delimiter is appended.
        if (opens_new_chunk && !pending.empty()) {
            tokens.push_back(pending);
            pending.clear();
        }

        pending += ch;
    }

    // Trailing chunk (if any).
    if (!pending.empty()) {
        tokens.push_back(pending);
    }

    return tokens;
}
|
||||
@@ -0,0 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// Splits the input into consecutive chunks. A new chunk begins at each
// delimiter character (whitespace and punctuation such as { } [ ] , " . < > = /),
// and the delimiter is kept as the first character of that chunk, so the
// concatenation of all chunks equals the input.
std::vector<std::string> simple_tokenize(const std::string &);
|
||||
@@ -0,0 +1,454 @@
|
||||
#include "tests.h"
|
||||
|
||||
// Basic PEG combinator checks, grouped into four suites:
//   - "chars":           escape handling inside character classes
//   - "optional":        optional sub-parsers, complete and partial input
//   - "partial parsing": success / fail / need-more-input for each combinator
//   - "recursive rules": rules that reference each other by name
// Every sub-test builds its own parser instance and its own parse context.
void test_basic(testing & t) {
    t.test("chars", [](testing & t) {
        // Class listing an escaped newline, an escaped tab and an escaped backslash.
        const auto make_escape_class = [] {
            return build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[\\n\\t\\\\]"); });
        };
        // Class where the dash is escaped, i.e. the three characters 'a', '-', 'z'.
        const auto make_escaped_dash_class = [] {
            return build_peg_parser([](common_peg_parser_builder & p) { return p.chars("[a\\-z]"); });
        };

        // Escape sequences: newline is accepted
        t.test("escape_sequence_newline", [&](testing & t) {
            auto matcher = make_escape_class();
            common_peg_parse_context input("\n");
            auto outcome = matcher.parse(input);
            t.assert_equal("escape_sequence_newline", true, outcome.success());
        });

        // Escape sequences: tab is accepted
        t.test("escape_sequence_tab", [&](testing & t) {
            auto matcher = make_escape_class();
            common_peg_parse_context input("\t");
            auto outcome = matcher.parse(input);
            t.assert_equal("escape_sequence_tab", true, outcome.success());
        });

        // Escape sequences: backslash is accepted
        t.test("escape_sequence_backslash", [&](testing & t) {
            auto matcher = make_escape_class();
            common_peg_parse_context input("\\");
            auto outcome = matcher.parse(input);
            t.assert_equal("escape_sequence_backslash", true, outcome.success());
        });

        // Escape sequences: a space is not in the class, so parsing must fail
        t.test("escape_sequence_space_fail", [&](testing & t) {
            auto matcher = make_escape_class();
            common_peg_parse_context input(" ");
            auto outcome = matcher.parse(input);
            t.assert_equal("escape_sequence_space_fail", true, outcome.fail());
        });

        // Escaped dash: 'a' is accepted
        t.test("escaped_dash_a", [&](testing & t) {
            auto matcher = make_escaped_dash_class();
            common_peg_parse_context input("a");
            auto outcome = matcher.parse(input);
            t.assert_equal("escaped_dash_a", true, outcome.success());
        });

        // Escaped dash: '-' itself is accepted (literal dash)
        t.test("escaped_dash_literal", [&](testing & t) {
            auto matcher = make_escaped_dash_class();
            common_peg_parse_context input("-");
            auto outcome = matcher.parse(input);
            t.assert_equal("escaped_dash_literal", true, outcome.success());
        });

        // Escaped dash: 'z' is accepted
        t.test("escaped_dash_z", [&](testing & t) {
            auto matcher = make_escaped_dash_class();
            common_peg_parse_context input("z");
            auto outcome = matcher.parse(input);
            t.assert_equal("escaped_dash_z", true, outcome.success());
        });

        // Escaped dash: 'b' must NOT match, because \- is a literal dash and not a range
        t.test("escaped_dash_b_fail", [&](testing & t) {
            auto matcher = make_escaped_dash_class();
            common_peg_parse_context input("b");
            auto outcome = matcher.parse(input);
            t.assert_equal("escaped_dash_b_fail", true, outcome.fail());
        });
    });

    t.test("optional", [](testing & t) {
        // "hello" followed by an optional " world".
        const auto make_greeting = [] {
            return build_peg_parser([](common_peg_parser_builder & p) {
                return p.literal("hello") + p.optional(p.literal(" world"));
            });
        };

        // Optional part present: the whole input is consumed
        t.test("optional_present", [&](testing & t) {
            auto matcher = make_greeting();
            common_peg_parse_context input("hello world");
            auto outcome = matcher.parse(input);
            t.assert_equal("optional_present", true, outcome.success());
            t.assert_equal("optional_present_end", 11u, outcome.end);
        });

        // Optional part absent: only the mandatory literal is consumed
        t.test("optional_absent", [&](testing & t) {
            auto matcher = make_greeting();
            common_peg_parse_context input("hello", false);
            auto outcome = matcher.parse(input);
            t.assert_equal("optional_absent", true, outcome.success());
            t.assert_equal("optional_absent_end", 5u, outcome.end);
        });

        // Partial input: more data is needed to decide whether the optional part matches
        t.test("partial_match_need_more", [&](testing & t) {
            auto matcher = make_greeting();
            common_peg_parse_context input("hello ", true);
            auto outcome = matcher.parse(input);
            t.assert_equal("partial_match_need_more", true, outcome.need_more_input());
        });
    });

    t.test("partial parsing", [](testing & t) {
        // Parsers that are reused by several cases in this suite.
        const auto make_lowercase = [] {
            return build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z"); });
        };
        const auto make_lowercase_or_dash = [] {
            return build_peg_parser([](common_peg_parser_builder & p) { return p.chars("a-z-"); });
        };
        const auto make_think_pair = [] {
            return build_peg_parser([](common_peg_parser_builder & p) { return p.literal("<think>") + p.literal("</think>"); });
        };

        // Literals: exact match
        t.test("literal_success", [&](testing & t) {
            auto matcher = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("hello"); });
            common_peg_parse_context input("hello");
            auto outcome = matcher.parse(input);
            t.assert_equal("literal_success", true, outcome.success());
        });

        // Character classes: lowercase letter matches
        t.test("char_class_lowercase_success", [&](testing & t) {
            auto matcher = make_lowercase();
            common_peg_parse_context input("a");
            auto outcome = matcher.parse(input);
            t.assert_equal("char_class_lowercase_success", true, outcome.success());
        });

        // Character classes: uppercase letter is rejected
        t.test("char_class_uppercase_fail", [&](testing & t) {
            auto matcher = make_lowercase();
            common_peg_parse_context input("A");
            auto outcome = matcher.parse(input);
            t.assert_equal("char_class_uppercase_fail", true, outcome.fail());
        });

        // Character classes with trailing dash: lowercase letter matches
        t.test("char_class_with_dash_lowercase", [&](testing & t) {
            auto matcher = make_lowercase_or_dash();
            common_peg_parse_context input("f");
            auto outcome = matcher.parse(input);
            t.assert_equal("char_class_with_dash_lowercase", true, outcome.success());
        });

        // Character classes with trailing dash: the dash itself matches
        t.test("char_class_with_dash_literal_dash", [&](testing & t) {
            auto matcher = make_lowercase_or_dash();
            common_peg_parse_context input("-");
            auto outcome = matcher.parse(input);
            t.assert_equal("char_class_with_dash_literal_dash", true, outcome.success());
        });

        // Character classes with trailing dash: uppercase letter is rejected
        t.test("char_class_with_dash_uppercase_fail", [&](testing & t) {
            auto matcher = make_lowercase_or_dash();
            common_peg_parse_context input("A");
            auto outcome = matcher.parse(input);
            t.assert_equal("char_class_with_dash_uppercase_fail", true, outcome.fail());
        });

        // Sequences: input stops inside the first literal
        t.test("sequence_partial_match_1", [&](testing & t) {
            auto matcher = make_think_pair();
            common_peg_parse_context input("<thi", true);
            auto outcome = matcher.parse(input);
            t.assert_equal("sequence_partial_match_1", true, outcome.need_more_input());
        });

        // Sequences: input stops right after the first literal
        t.test("sequence_partial_match_2", [&](testing & t) {
            auto matcher = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("begin") + p.literal("end"); });
            common_peg_parse_context input("begin", true);
            auto outcome = matcher.parse(input);
            t.assert_equal("sequence_partial_match_2", true, outcome.need_more_input());
        });

        // Sequences: input stops inside the second literal
        t.test("sequence_partial_match_3", [&](testing & t) {
            auto matcher = make_think_pair();
            common_peg_parse_context input("<think></", true);
            auto outcome = matcher.parse(input);
            t.assert_equal("sequence_partial_match_3", true, outcome.need_more_input());
        });

        // Sequences: complete match
        t.test("sequence_full_match", [&](testing & t) {
            auto matcher = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("hello") + p.literal("world"); });
            common_peg_parse_context input("helloworld", false);
            auto outcome = matcher.parse(input);
            t.assert_equal("sequence_full_match", true, outcome.success());
        });

        // Sequences: second literal can no longer match
        t.test("sequence_no_match", [&](testing & t) {
            auto matcher = make_think_pair();
            common_peg_parse_context input("<think>I am common_chat_combinator_parser", true);
            auto outcome = matcher.parse(input);
            t.assert_equal("sequence_no_match", true, outcome.fail());
        });

        // Choices: common prefix of both alternatives (1)
        t.test("choices_partial_match_1", [&](testing & t) {
            auto matcher = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("option1") | p.literal("option2"); });
            common_peg_parse_context input("opt", true);
            auto outcome = matcher.parse(input);
            t.assert_equal("choices_partial_match_1", true, outcome.need_more_input());
        });

        // Choices: common prefix of both alternatives (2)
        t.test("choices_partial_match_2", [&](testing & t) {
            auto matcher = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("choice_a") | p.literal("choice_b"); });
            common_peg_parse_context input("choice", true);
            auto outcome = matcher.parse(input);
            t.assert_equal("choices_partial_match_2", true, outcome.need_more_input());
        });

        // Choices: first alternative matches
        t.test("choices_full_match_1", [&](testing & t) {
            auto matcher = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("first") | p.literal("second"); });
            common_peg_parse_context input("first", false);
            auto outcome = matcher.parse(input);
            t.assert_equal("choices_full_match_1", true, outcome.success());
        });

        // Choices: second alternative matches
        t.test("choices_full_match_2", [&](testing & t) {
            auto matcher = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("alpha") | p.literal("beta"); });
            common_peg_parse_context input("beta", false);
            auto outcome = matcher.parse(input);
            t.assert_equal("choices_full_match_2", true, outcome.success());
        });

        // Choices: neither alternative matches
        t.test("choices_no_match", [&](testing & t) {
            auto matcher = build_peg_parser([](common_peg_parser_builder & p) { return p.literal("good") | p.literal("better"); });
            common_peg_parse_context input("best", false);
            auto outcome = matcher.parse(input);
            t.assert_equal("choices_no_match", true, outcome.fail());
        });

        // Zero or more: input stops inside the first repetition
        t.test("zero_or_more_partial_match_1", [&](testing & t) {
            auto matcher = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("ab")); });
            common_peg_parse_context input("a", true);
            auto outcome = matcher.parse(input);
            t.assert_equal("zero_or_more_partial_match_1", true, outcome.need_more_input());
        });

        // Zero or more: input stops inside the second repetition
        t.test("zero_or_more_partial_match_2", [&](testing & t) {
            auto matcher = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("xy")); });
            common_peg_parse_context input("xyx", true);
            auto outcome = matcher.parse(input);
            t.assert_equal("zero_or_more_partial_match_2", true, outcome.need_more_input());
        });

        // Zero or more: complete match
        t.test("zero_or_more_full_match", [&](testing & t) {
            auto matcher = build_peg_parser([](common_peg_parser_builder & p) { return p.zero_or_more(p.literal("test")); });
            common_peg_parse_context input("test", false);
            auto outcome = matcher.parse(input);
            t.assert_equal("zero_or_more_full_match", true, outcome.success());
        });

        // One or more: input stops inside the first repetition
        t.test("one_or_more_partial_match_1", [&](testing & t) {
            auto matcher = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("repeat")); });
            common_peg_parse_context input("rep", true);
            auto outcome = matcher.parse(input);
            t.assert_equal("one_or_more_partial_match_1", true, outcome.need_more_input());
        });

        // One or more: input stops inside the second repetition
        t.test("one_or_more_partial_match_2", [&](testing & t) {
            auto matcher = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("ab")); });
            common_peg_parse_context input("aba", true);
            auto outcome = matcher.parse(input);
            t.assert_equal("one_or_more_partial_match_2", true, outcome.need_more_input());
        });

        // One or more: complete match
        t.test("one_or_more_full_match", [&](testing & t) {
            auto matcher = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("single")); });
            common_peg_parse_context input("single", false);
            auto outcome = matcher.parse(input);
            t.assert_equal("one_or_more_full_match", true, outcome.success());
        });

        // One or more: not even a single repetition matches
        t.test("one_or_more_no_match", [&](testing & t) {
            auto matcher = build_peg_parser([](common_peg_parser_builder & p) { return p.one_or_more(p.literal("()")); });
            common_peg_parse_context input("success", false);
            auto outcome = matcher.parse(input);
            t.assert_equal("one_or_more_no_match", true, outcome.fail());
        });
    });

    t.test("recursive rules", [](testing & t) {
        // value := number | list, list := "[" value "]", number := one char of 0-9.
        // "list" refers to "value" before "value" is defined, hence p.ref().
        const auto make_value_grammar = [] {
            return build_peg_parser([](common_peg_parser_builder & p) {
                p.rule("number", p.chars("0-9"));
                p.rule("list", p.literal("[") + p.ref("value") + p.literal("]"));
                return p.rule("value", p.ref("number") | p.ref("list"));
            });
        };

        // A bare number
        t.test("simple_number", [&](testing & t) {
            auto grammar = make_value_grammar();
            common_peg_parse_context input("1", false);
            auto outcome = grammar.parse(input);
            t.assert_equal("result_is_success", true, outcome.success());
        });

        // A list holding one number
        t.test("simple_list", [&](testing & t) {
            auto grammar = make_value_grammar();
            common_peg_parse_context input("[1]", false);
            auto outcome = grammar.parse(input);
            t.assert_equal("result_is_success", true, outcome.success());
        });

        // A list inside a list
        t.test("nested_list", [&](testing & t) {
            auto grammar = make_value_grammar();
            common_peg_parse_context input("[[2]]", false);
            auto outcome = grammar.parse(input);
            t.assert_equal("result_is_success", true, outcome.success());
        });

        // Three levels of nesting
        t.test("deeply_nested_list", [&](testing & t) {
            auto grammar = make_value_grammar();
            common_peg_parse_context input("[[[3]]]", false);
            auto outcome = grammar.parse(input);
            t.assert_equal("result_is_success", true, outcome.success());
        });

        // Truncated input in partial mode asks for more data
        t.test("need_more_input_match", [&](testing & t) {
            auto grammar = make_value_grammar();
            common_peg_parse_context input("[[", true);
            auto outcome = grammar.parse(input);
            t.assert_equal("result_is_need_more_input", true, outcome.need_more_input());
        });

        // A non-digit inside the list is rejected
        t.test("no_match", [&](testing & t) {
            auto grammar = make_value_grammar();
            common_peg_parse_context input("[a]", false);
            auto outcome = grammar.parse(input);
            t.assert_equal("result_is_fail", true, outcome.fail());
        });
    });
}
|
||||
@@ -0,0 +1,250 @@
|
||||
#include "tests.h"
|
||||
|
||||
#include "json-schema-to-grammar.h"
|
||||
|
||||
#include <regex>
|
||||
|
||||
// Removes the run of whitespace found at the very start of `s` and directly
// after every newline (the newline itself is kept). Because the run may itself
// contain newlines, blank lines following a newline are collapsed as well.
// Used to compare grammars without caring about source indentation.
static std::string trim_leading_space(const std::string & s) {
    // Group 1 is either the start of the string (empty) or a newline; only
    // that group is written back, which drops the whitespace after it.
    static const std::regex indent_pattern(R"((^|\n)\s+)");

    std::string stripped = std::regex_replace(s, indent_pattern, "$1");
    return stripped;
}
|
||||
|
||||
// Asserts that two GBNF grammars are equal once the whitespace at the start of
// the text and after each newline has been stripped from both sides.
static void assert_gbnf_equal(testing & t, const std::string & expected, const std::string & actual) {
    const std::string want = trim_leading_space(expected);
    const std::string got  = trim_leading_space(actual);

    t.assert_equal("gbnf are equal", want, got);
}
|
||||
|
||||
void test_gbnf_generation(testing &t) {
|
||||
t.test("literal grammar generation", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "hello"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("char class grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.chars("[a-z]", 1, 1);
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= [a-z]
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("sequence grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello") + p.literal(" ") + p.literal("world");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "hello" " " "world"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("choice grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("cat") | p.literal("dog");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "cat" | "dog"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("one_or_more grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.one_or_more(p.literal("a"));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "a"+
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("zero_or_more grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.zero_or_more(p.literal("a"));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "a"*
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("optional grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello") + p.optional(p.literal(" world"));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "hello" " world"?
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("until grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.until("</tag>");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= ([^<] | "<" [^/] | "</" [^t] | "</t" [^a] | "</ta" [^g] | "</tag" [^>])*
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("complex expressions with parentheses", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.one_or_more(p.literal("a") | p.literal("b"));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= ("a" | "b")+
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("rule references", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
auto digit = p.rule("digit", p.chars("[0-9]", 1, 1));
|
||||
return p.one_or_more(digit);
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
digit ::= [0-9]
|
||||
root ::= digit+
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("escaping in literals", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello\nworld\n!");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "hello\nworld\n!"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("operator<< (whitespace insertion)", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.literal("hello") << p.literal("world");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= "hello" space "world"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("emit only reachable rules", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
p.rule("orphan", p.literal("orphan"));
|
||||
return p.literal("hello") + p.rule("child", p.literal(" world"));
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
child ::= " world"
|
||||
root ::= "hello" child
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("emit only trigger rules (and references)", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
auto rule1 = p.rule("rule-1", p.literal("a") + p.ref("rule-2"));
|
||||
p.rule("rule-2", p.literal("b") + p.ref("rule-3"), true);
|
||||
p.rule("rule-3", p.literal("c") + p.ref("rule-4"));
|
||||
p.rule("rule-4", p.literal("d"), true);
|
||||
return rule1;
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= rule-1
|
||||
rule-1 ::= "a" rule-2
|
||||
rule-2 ::= "b" rule-3
|
||||
rule-3 ::= "c" rule-4
|
||||
rule-4 ::= "d"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
|
||||
auto gbnf_lazy = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder, true);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= rule-2 | rule-4
|
||||
rule-2 ::= "b" rule-3
|
||||
rule-3 ::= "c" rule-4
|
||||
rule-4 ::= "d"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf_lazy);
|
||||
});
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
#include "tests.h"
|
||||
|
||||
void test_json_parser(testing &t) {
|
||||
// Test parsing a simple JSON object
|
||||
t.test("simple JSON object parsing", [](testing &t) {
|
||||
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
|
||||
|
||||
std::string input = R"({"name": "test", "value": 42, "flag": true})";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = json.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test parsing a JSON array with mixed types
|
||||
t.test("JSON array with mixed types", [](testing &t) {
|
||||
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
|
||||
|
||||
std::string input = R"([1, "hello", true, null, 3.14])";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = json.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test parsing nested JSON with objects and arrays
|
||||
t.test("nested JSON with objects and arrays", [](testing &t) {
|
||||
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
|
||||
|
||||
std::string input =
|
||||
R"({"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], "count": 2, "metadata": {"version": "1.0", "tags": ["admin", "user"]}})";
|
||||
common_peg_parse_context ctx(input);
|
||||
|
||||
auto result = json.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_success", true, result.success());
|
||||
t.assert_equal("result_end", input.size(), result.end);
|
||||
});
|
||||
|
||||
// Test need_more_input() parsing - incomplete object
|
||||
t.test("need_more_input() parsing - incomplete object", [](testing &t) {
|
||||
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
|
||||
|
||||
std::string input = R"({"name": "test", "value": )";
|
||||
common_peg_parse_context ctx(input, true);
|
||||
|
||||
auto result = json.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_need_more_input", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Test need_more_input() parsing - incomplete array
|
||||
t.test("need_more_input() parsing - incomplete array", [](testing &t) {
|
||||
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
|
||||
|
||||
std::string input = R"([1, 2, 3, )";
|
||||
common_peg_parse_context ctx(input, true);
|
||||
|
||||
auto result = json.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_need_more_input", true, result.need_more_input());
|
||||
});
|
||||
|
||||
// Test need_more_input() parsing - incomplete nested structure
|
||||
t.test("need_more_input() parsing - incomplete nested structure", [](testing &t) {
|
||||
auto json = build_peg_parser([](common_peg_parser_builder & p) { return p.json(); });
|
||||
|
||||
std::string input = R"({"data": {"nested": )";
|
||||
common_peg_parse_context ctx(input, true);
|
||||
|
||||
auto result = json.parse(ctx);
|
||||
|
||||
t.assert_equal("result_is_need_more_input", true, result.need_more_input());
|
||||
});
|
||||
|
||||
t.test("object member", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.json_member("name", "\"" + p.chars("[a-z]") + "\"");
|
||||
});
|
||||
|
||||
t.test("success", [&](testing &t) {
|
||||
std::string input = R"("name": "bob")";
|
||||
common_peg_parse_context ctx(input, false);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("success", result.success());
|
||||
});
|
||||
|
||||
t.test("partial", [&](testing &t) {
|
||||
std::string input = R"("name": "bo)";
|
||||
common_peg_parse_context ctx(input, true);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("need more input", result.need_more_input());
|
||||
});
|
||||
|
||||
t.test("failed", [&](testing &t) {
|
||||
std::string input = R"([])";
|
||||
common_peg_parse_context ctx(input, false);
|
||||
|
||||
auto result = parser.parse(ctx);
|
||||
t.assert_true("fail", result.fail());
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
#include "tests.h"
|
||||
|
||||
// Round-trips a PEG parser through its JSON representation and checks that the
// parser rebuilt from JSON behaves like the original on the same input.
// Also benchmarks the JSON -> parser deserialization step.
void test_json_serialization(testing &t) {
    auto original = build_peg_parser([](common_peg_parser_builder & p) {
        return "<tool_call>" + p.json() + "</tool_call>";
    });

    auto json_serialized = original.to_json().dump();

    t.test("compare before/after", [&](testing &t) {
        auto deserialized = common_peg_arena::from_json(nlohmann::json::parse(json_serialized));

        // Complex JSON payload wrapped in the tags the grammar requires. Without the
        // wrapper both parsers reject the input at the first literal, so comparing
        // their outcomes would pass without exercising the JSON rules at all.
        std::string input = R"(<tool_call>{"name": "test", "values": [1, 2, 3], "nested": {"a": true}}</tool_call>)";
        common_peg_parse_context ctx1(input);
        common_peg_parse_context ctx2(input);

        auto result1 = original.parse(ctx1);
        auto result2 = deserialized.parse(ctx2);

        // The original must accept the input, and the round-tripped parser must agree with it
        t.assert_true("original_succeeds", result1.success());
        t.assert_equal("both_succeed", result1.success(), result2.success());
        t.assert_equal("same_end_pos", result1.end, result2.end);
    });

    t.bench("deserialize", [&]() {
        auto deserialized = common_peg_arena::from_json(nlohmann::json::parse(json_serialized));
    }, 100);
}
|
||||
@@ -0,0 +1,449 @@
|
||||
#include "tests.h"
|
||||
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
#include <cctype>
|
||||
|
||||
// Asserts that two parse result types are equal. The comparison is done on the
// printable names so that a failure report shows readable values.
static void assert_result_equal(testing & t, common_peg_parse_result_type expected, common_peg_parse_result_type actual) {
    // Materialize the names as std::string so the equality check compares contents;
    // if the name helper hands back C strings, comparing them directly would compare addresses.
    const std::string expected_name(common_peg_parse_result_type_name(expected));
    const std::string actual_name(common_peg_parse_result_type_name(actual));
    t.assert_equal(expected_name, actual_name);
}
|
||||
|
||||
// Renders a byte string for display: printable bytes are copied through
// unchanged, every other byte becomes a `\xNN` escape (two lowercase hex digits).
static std::string hex_dump(const std::string& str) {
    static const char digits[] = "0123456789abcdef";

    std::string rendered;
    for (std::string::size_type pos = 0; pos < str.size(); ++pos) {
        const unsigned char byte = static_cast<unsigned char>(str[pos]);
        if (!std::isprint(byte)) {
            rendered += "\\x";
            rendered += digits[byte >> 4];
            rendered += digits[byte & 0x0F];
            continue;
        }
        rendered += static_cast<char>(byte);
    }
    return rendered;
}
|
||||
|
||||
void test_unicode(testing &t) {
|
||||
struct test_case {
|
||||
std::string input;
|
||||
std::string expected_text;
|
||||
common_peg_parse_result_type expected_result;
|
||||
};
|
||||
|
||||
t.test("any", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Valid UTF-8 sequences
|
||||
{"Hello", "Hello", COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
{std::string("Caf\xC3\xA9"), std::string("Caf\xC3\xA9"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
{std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
{std::string("\xF0\x9F\x9A\x80"), std::string("\xF0\x9F\x9A\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// Incomplete UTF-8 sequences (partial bytes at end)
|
||||
{std::string("Caf\xC3"), "Caf", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
{std::string("\xE4\xBD"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
{std::string("\xF0\x9F\x9A"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
|
||||
// Invalid/malformed UTF-8 sequences
|
||||
{std::string("\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
{std::string("Hello\x80World"), "Hello", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
{std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
};
|
||||
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.sequence({p.one_or_more(p.any()), p.end()});
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
common_peg_parse_context ctx(tc.input, true);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
// Assert result type matches
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
// Assert matched text if success or need_more_input
|
||||
if (result.success() || result.need_more_input()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start);
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("char classes", [](testing &t) {
|
||||
t.test("unicode range U+4E00-U+9FFF (CJK)", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Within range - CJK Unified Ideographs
|
||||
{std::string("\xE4\xB8\x80"), std::string("\xE4\xB8\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4E00
|
||||
{std::string("\xE4\xBD\xA0"), std::string("\xE4\xBD\xA0"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4F60
|
||||
{std::string("\xE5\xA5\xBD"), std::string("\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+597D
|
||||
{std::string("\xE9\xBF\xBF"), std::string("\xE9\xBF\xBF"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+9FFF
|
||||
|
||||
// Outside range - should fail
|
||||
{"a", "", COMMON_PEG_PARSE_RESULT_FAIL}, // ASCII
|
||||
{std::string("\xE4\xB7\xBF"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+4DFF (before range)
|
||||
{std::string("\xEA\x80\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+A000 (after range)
|
||||
|
||||
// Incomplete sequences in range
|
||||
{std::string("\xE4\xB8"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete U+4E00
|
||||
{std::string("\xE5\xA5"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete U+597D
|
||||
};
|
||||
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.sequence({p.chars(R"([\u4E00-\u9FFF])"), p.end()});
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
common_peg_parse_context ctx(tc.input, true);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
// Assert result type matches
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
// Assert matched text if success or need_more_input
|
||||
if (result.success() || result.need_more_input()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start);
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("unicode range U+1F600-U+1F64F (emoticons)", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Within range - Emoticons (all 4-byte UTF-8)
|
||||
{std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F600
|
||||
{std::string("\xF0\x9F\x98\x81"), std::string("\xF0\x9F\x98\x81"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F601
|
||||
{std::string("\xF0\x9F\x99\x8F"), std::string("\xF0\x9F\x99\x8F"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F64F
|
||||
|
||||
// Outside range
|
||||
{std::string("\xF0\x9F\x97\xBF"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F5FF (before range)
|
||||
{std::string("\xF0\x9F\x99\x90"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F650 (after range)
|
||||
{std::string("\xF0\x9F\x9A\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F680 (outside range)
|
||||
|
||||
// Incomplete sequences
|
||||
{std::string("\xF0\x9F\x98"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Incomplete emoji
|
||||
{std::string("\xF0\x9F"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT}, // Very incomplete
|
||||
};
|
||||
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.sequence({p.chars(R"([\U0001F600-\U0001F64F])"), p.end()});
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
common_peg_parse_context ctx(tc.input, true);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
// Assert result type matches
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
// Assert matched text if success or need_more_input
|
||||
if (result.success() || result.need_more_input()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start);
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("mixed unicode ranges", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Match CJK
|
||||
{std::string("\xE4\xB8\x80"), std::string("\xE4\xB8\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4E00
|
||||
{std::string("\xE4\xBD\xA0"), std::string("\xE4\xBD\xA0"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+4F60
|
||||
|
||||
// Match emoticons
|
||||
{std::string("\xF0\x9F\x98\x80"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS}, // U+1F600
|
||||
|
||||
// Match ASCII digits
|
||||
{"5", "5", COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// Don't match outside any range
|
||||
{"a", "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
{std::string("\xF0\x9F\x9A\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL}, // U+1F680
|
||||
|
||||
// Incomplete
|
||||
{std::string("\xE4\xB8"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
{std::string("\xF0\x9F\x98"), "", COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
};
|
||||
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.sequence({p.chars(R"([\u4E00-\u9FFF\U0001F600-\U0001F64F0-9])"), p.end()});
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
common_peg_parse_context ctx(tc.input, true);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
// Assert result type matches
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
// Assert matched text if success or need_more_input
|
||||
if (result.success() || result.need_more_input()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start);
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
t.test("until parser", [](testing &t) {
|
||||
t.test("ASCII delimiter with Unicode content", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// CJK characters before delimiter
|
||||
{std::string("\xE4\xBD\xA0\xE5\xA5\xBD</tag>"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// Emoji before delimiter
|
||||
{std::string("\xF0\x9F\x98\x80</tag>"), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// Mixed content
|
||||
{std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!</tag>"), std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
};
|
||||
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.until("</tag>");
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
common_peg_parse_context ctx(tc.input, false);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
if (result.success()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start);
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("incomplete UTF-8 at end", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Incomplete emoji at end, no delimiter
|
||||
{std::string("content\xF0\x9F\x98"), std::string("content"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
|
||||
// Incomplete CJK at end, no delimiter
|
||||
{std::string("hello\xE4\xB8"), std::string("hello"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
|
||||
// Complete content, no delimiter (should consume all valid UTF-8)
|
||||
{std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
};
|
||||
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.until("</tag>");
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
common_peg_parse_context ctx(tc.input, true);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
if (result.success() || result.need_more_input()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start);
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("malformed UTF-8", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Invalid UTF-8 bytes
|
||||
{std::string("Hello\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
|
||||
// Continuation byte without lead byte
|
||||
{std::string("Hello\x80World"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
|
||||
// Invalid continuation byte
|
||||
{std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
};
|
||||
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.until("</tag>");
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
common_peg_parse_context ctx(tc.input, false);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
t.test("json_string parser", [](testing &t) {
|
||||
t.test("valid UTF-8 characters", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// ASCII only
|
||||
{"Hello World\"", "Hello World", COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// 2-byte UTF-8 (accented characters)
|
||||
{std::string("Caf\xC3\xA9\""), std::string("Caf\xC3\xA9"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// 3-byte UTF-8 (CJK)
|
||||
{std::string("\xE4\xBD\xA0\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// 4-byte UTF-8 (emoji)
|
||||
{std::string("\xF0\x9F\x98\x80\""), std::string("\xF0\x9F\x98\x80"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// Mixed content
|
||||
{std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!\""), std::string("Hello \xE4\xB8\x96\xE7\x95\x8C!"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.sequence({p.json_string_content(), p.literal("\"")});
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(tc.input, false);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
if (result.success()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start - 1); // -1 to exclude closing quote
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("incomplete UTF-8", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Incomplete 2-byte sequence
|
||||
{std::string("Caf\xC3"), std::string("Caf"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
|
||||
// Incomplete 3-byte sequence
|
||||
{std::string("Hello\xE4\xB8"), std::string("Hello"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
|
||||
// Incomplete 4-byte sequence
|
||||
{std::string("Text\xF0\x9F\x98"), std::string("Text"), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
|
||||
// Incomplete at very start
|
||||
{std::string("\xE4\xBD"), std::string(""), COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT},
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.json_string_content();
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(tc.input, true);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
if (result.need_more_input()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start);
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("malformed UTF-8", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Invalid UTF-8 bytes
|
||||
{std::string("Hello\xFF\xFE"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
|
||||
// Continuation byte without lead byte
|
||||
{std::string("Hello\x80World"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
|
||||
// Invalid continuation byte
|
||||
{std::string("\xC3\x28"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
|
||||
// Overlong encoding (security issue)
|
||||
{std::string("\xC0\x80"), "", COMMON_PEG_PARSE_RESULT_FAIL},
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.json_string_content();
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(tc.input, false);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
t.test("escape sequences with UTF-8", [](testing &t) {
|
||||
std::vector<test_case> test_cases {
|
||||
// Unicode escape sequence
|
||||
{"Hello\\u0041\"", "Hello\\u0041", COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// Mix of UTF-8 and escape sequences
|
||||
{std::string("\xE4\xBD\xA0\\n\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\\n\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
|
||||
// Escaped quote in UTF-8 string
|
||||
{std::string("\xE4\xBD\xA0\\\"\xE5\xA5\xBD\""), std::string("\xE4\xBD\xA0\\\"\xE5\xA5\xBD"), COMMON_PEG_PARSE_RESULT_SUCCESS},
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < test_cases.size(); i++) {
|
||||
const auto & tc = test_cases[i];
|
||||
std::string test_name = "case " + std::to_string(i) + ": " + hex_dump(tc.input);
|
||||
|
||||
t.test(test_name, [&](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder& p) {
|
||||
return p.sequence({p.json_string_content(), p.literal("\"")});
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(tc.input, false);
|
||||
auto result = parser.parse(ctx);
|
||||
|
||||
assert_result_equal(t, tc.expected_result, result.type);
|
||||
|
||||
if (result.success()) {
|
||||
std::string matched = tc.input.substr(result.start, result.end - result.start - 1); // -1 to exclude closing quote
|
||||
t.assert_equal(tc.expected_text, matched);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -0,0 +1,243 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <regex>
|
||||
#include <vector>
|
||||
|
||||
struct testing {
|
||||
std::ostream &out;
|
||||
std::vector<std::string> stack;
|
||||
std::regex filter;
|
||||
bool filter_tests = false;
|
||||
bool throw_exception = false;
|
||||
bool verbose = false;
|
||||
int tests = 0;
|
||||
int assertions = 0;
|
||||
int failures = 0;
|
||||
int unnamed = 0;
|
||||
int exceptions = 0;
|
||||
|
||||
static constexpr std::size_t status_column = 80;
|
||||
|
||||
explicit testing(std::ostream &os = std::cout) : out(os) {}
|
||||
|
||||
std::string indent() const {
|
||||
if (stack.empty()) {
|
||||
return "";
|
||||
}
|
||||
return std::string((stack.size() - 1) * 2, ' ');
|
||||
}
|
||||
|
||||
std::string full_name() const {
|
||||
return string_join(stack, ".");
|
||||
}
|
||||
|
||||
void log(const std::string & msg) {
|
||||
if (verbose) {
|
||||
out << indent() << " " << msg << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
void set_filter(const std::string & re) {
|
||||
filter = std::regex(re);
|
||||
filter_tests = true;
|
||||
}
|
||||
|
||||
bool should_run() const {
|
||||
if (filter_tests) {
|
||||
if (!std::regex_match(full_name(), filter)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void run_with_exceptions(F &&f, const char *ctx) {
|
||||
try {
|
||||
f();
|
||||
} catch (const std::exception &e) {
|
||||
++failures;
|
||||
++exceptions;
|
||||
out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): " << e.what() << "\n";
|
||||
if (throw_exception) {
|
||||
throw;
|
||||
}
|
||||
} catch (...) {
|
||||
++failures;
|
||||
++exceptions;
|
||||
out << indent() << "UNHANDLED EXCEPTION (" << ctx << "): unknown\n";
|
||||
if (throw_exception) {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void print_result(const std::string &label, int new_failures, int new_assertions, const std::string &extra = "") const {
|
||||
std::string line = indent() + label;
|
||||
|
||||
std::string details;
|
||||
if (new_assertions > 0) {
|
||||
if (new_failures == 0) {
|
||||
details = std::to_string(new_assertions) + " assertion(s)";
|
||||
} else {
|
||||
details = std::to_string(new_failures) + " of " +
|
||||
std::to_string(new_assertions) + " assertion(s) failed";
|
||||
}
|
||||
}
|
||||
if (!extra.empty()) {
|
||||
if (!details.empty()) {
|
||||
details += ", ";
|
||||
}
|
||||
details += extra;
|
||||
}
|
||||
|
||||
if (!details.empty()) {
|
||||
line += " (" + details + ")";
|
||||
}
|
||||
|
||||
std::string status = (new_failures == 0) ? "[PASS]" : "[FAIL]";
|
||||
|
||||
if (line.size() + 1 < status_column) {
|
||||
line.append(status_column - line.size(), ' ');
|
||||
} else {
|
||||
line.push_back(' ');
|
||||
}
|
||||
|
||||
out << line << status << "\n";
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void test(const std::string &name, F f) {
|
||||
stack.push_back(name);
|
||||
if (!should_run()) {
|
||||
stack.pop_back();
|
||||
return;
|
||||
}
|
||||
|
||||
++tests;
|
||||
out << indent() << name << "\n";
|
||||
|
||||
int before_failures = failures;
|
||||
int before_assertions = assertions;
|
||||
|
||||
run_with_exceptions([&] { f(*this); }, "test");
|
||||
|
||||
int new_failures = failures - before_failures;
|
||||
int new_assertions = assertions - before_assertions;
|
||||
|
||||
print_result(name, new_failures, new_assertions);
|
||||
|
||||
stack.pop_back();
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void test(F f) {
|
||||
test("test #" + std::to_string(++unnamed), f);
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void bench(const std::string &name, F f, int iterations = 100) {
|
||||
stack.push_back(name);
|
||||
if (!should_run()) {
|
||||
stack.pop_back();
|
||||
return;
|
||||
}
|
||||
|
||||
++tests;
|
||||
out << indent() << "[bench] " << name << "\n";
|
||||
|
||||
int before_failures = failures;
|
||||
int before_assertions = assertions;
|
||||
|
||||
using clock = std::chrono::high_resolution_clock;
|
||||
|
||||
std::chrono::microseconds duration(0);
|
||||
|
||||
run_with_exceptions([&] {
|
||||
for (auto i = 0; i < iterations; i++) {
|
||||
auto start = clock::now();
|
||||
f();
|
||||
duration += std::chrono::duration_cast<std::chrono::microseconds>(clock::now() - start);
|
||||
}
|
||||
}, "bench");
|
||||
|
||||
auto avg_elapsed = duration.count() / iterations;
|
||||
auto avg_elapsed_s = std::chrono::duration_cast<std::chrono::duration<double>>(duration).count() / iterations;
|
||||
auto rate = (avg_elapsed_s > 0.0) ? (1.0 / avg_elapsed_s) : 0.0;
|
||||
|
||||
int new_failures = failures - before_failures;
|
||||
int new_assertions = assertions - before_assertions;
|
||||
|
||||
std::string extra =
|
||||
"n=" + std::to_string(iterations) +
|
||||
" avg=" + std::to_string(avg_elapsed) + "us" +
|
||||
" rate=" + std::to_string(int(rate)) + "/s";
|
||||
|
||||
print_result("[bench] " + name, new_failures, new_assertions, extra);
|
||||
|
||||
stack.pop_back();
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void bench(F f, int iterations = 100) {
|
||||
bench("bench #" + std::to_string(++unnamed), f, iterations);
|
||||
}
|
||||
|
||||
// Assertions
|
||||
bool assert_true(bool cond) {
|
||||
return assert_true("", cond);
|
||||
}
|
||||
|
||||
bool assert_true(const std::string &msg, bool cond) {
|
||||
++assertions;
|
||||
if (!cond) {
|
||||
++failures;
|
||||
out << indent() << "ASSERT TRUE FAILED";
|
||||
if (!msg.empty()) {
|
||||
out << " : " << msg;
|
||||
}
|
||||
out << "\n";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename A, typename B>
|
||||
bool assert_equal(const A &expected, const B &actual) {
|
||||
return assert_equal("", expected, actual);
|
||||
}
|
||||
|
||||
template <typename A, typename B>
|
||||
bool assert_equal(const std::string &msg, const A &expected, const B &actual) {
|
||||
++assertions;
|
||||
if (!(actual == expected)) {
|
||||
++failures;
|
||||
out << indent() << "ASSERT EQUAL FAILED";
|
||||
if (!msg.empty()) {
|
||||
out << " : " << msg;
|
||||
}
|
||||
out << "\n";
|
||||
|
||||
out << indent() << " expected: " << expected << "\n";
|
||||
out << indent() << " actual : " << actual << "\n";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Print summary and return an exit code
|
||||
int summary() const {
|
||||
out << "\n";
|
||||
out << "tests : " << tests << "\n";
|
||||
out << "assertions : " << assertions << "\n";
|
||||
out << "failures : " << failures << "\n";
|
||||
out << "exceptions : " << exceptions << "\n";
|
||||
return failures == 0 ? 0 : 1;
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
|
||||
// Common includes for all test files
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "testing.h"
|
||||
#include "peg-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "simple-tokenize.h"
|
||||
|
||||
// Plain record describing a single tool call: an id, a tool name and its
// arguments as order-preserving JSON.
// NOTE(review): no user of this struct is visible in this chunk - confirm where it is consumed.
struct bench_tool_call {
|
||||
std::string id;
|
||||
std::string name;
|
||||
// arguments of the call; ordered_json keeps key insertion order
nlohmann::ordered_json args;
|
||||
};
|
||||
|
||||
// Test function declarations
|
||||
void test_basic(testing &t);
|
||||
void test_json_parser(testing &t);
|
||||
void test_gbnf_generation(testing &t);
|
||||
void test_unicode(testing &t);
|
||||
void test_json_serialization(testing &t);
|
||||
+38
-22
@@ -41,12 +41,18 @@
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
// Number of worker threads used by the test helpers.
// Emscripten builds use a single thread. Elsewhere the hardware thread count is
// used, clamped to at least 1 because std::thread::hardware_concurrency() may
// return 0 when the value cannot be determined.
#ifdef __EMSCRIPTEN__
#    define N_THREADS 1
#else
#    define N_THREADS (std::max<unsigned int>(1u, std::thread::hardware_concurrency()))
#endif
|
||||
|
||||
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
|
||||
size_t nels = ggml_nelements(tensor);
|
||||
std::vector<float> data(nels);
|
||||
{
|
||||
// parallel initialization
|
||||
static const size_t n_threads = std::thread::hardware_concurrency();
|
||||
static const size_t n_threads = N_THREADS;
|
||||
// static RNG initialization (revisit if n_threads stops being constant)
|
||||
static std::vector<std::default_random_engine> generators = []() {
|
||||
std::random_device rd;
|
||||
@@ -65,15 +71,19 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<std::future<void>> tasks;
|
||||
tasks.reserve(n_threads);
|
||||
for (size_t i = 0; i < n_threads; i++) {
|
||||
size_t start = i*nels/n_threads;
|
||||
size_t end = (i+1)*nels/n_threads;
|
||||
tasks.push_back(std::async(std::launch::async, init_thread, i, start, end));
|
||||
}
|
||||
for (auto & t : tasks) {
|
||||
t.get();
|
||||
if (n_threads == 1) {
|
||||
init_thread(0, 0, nels);
|
||||
} else {
|
||||
std::vector<std::future<void>> tasks;
|
||||
tasks.reserve(n_threads);
|
||||
for (size_t i = 0; i < n_threads; i++) {
|
||||
size_t start = i*nels/n_threads;
|
||||
size_t end = (i+1)*nels/n_threads;
|
||||
tasks.push_back(std::async(std::launch::async, init_thread, i, start, end));
|
||||
}
|
||||
for (auto & t : tasks) {
|
||||
t.get();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,17 +115,23 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
|
||||
};
|
||||
|
||||
const size_t min_blocks_per_thread = 1;
|
||||
const size_t n_threads = std::min<size_t>(std::thread::hardware_concurrency()/2,
|
||||
std::max<size_t>(1, n_blocks / min_blocks_per_thread));
|
||||
std::vector<std::future<void>> tasks;
|
||||
tasks.reserve(n_threads);
|
||||
for (size_t i = 0; i < n_threads; i++) {
|
||||
size_t start = i*n_blocks/n_threads;
|
||||
size_t end = (i+1)*n_blocks/n_threads;
|
||||
tasks.push_back(std::async(std::launch::async, quantize_thread, start, end));
|
||||
}
|
||||
for (auto & t : tasks) {
|
||||
t.get();
|
||||
const size_t n_quant_threads = std::min<size_t>(std::max<size_t>(N_THREADS/2, 1),
|
||||
std::max<size_t>(1, n_blocks / min_blocks_per_thread));
|
||||
|
||||
if (n_quant_threads == 1) {
|
||||
// single-threaded quantization: do all blocks in the current thread
|
||||
quantize_thread(0, n_blocks);
|
||||
} else {
|
||||
std::vector<std::future<void>> tasks;
|
||||
tasks.reserve(n_quant_threads);
|
||||
for (size_t i = 0; i < n_quant_threads; i++) {
|
||||
size_t start = i*n_blocks/n_quant_threads;
|
||||
size_t end = (i+1)*n_blocks/n_quant_threads;
|
||||
tasks.push_back(std::async(std::launch::async, quantize_thread, start, end));
|
||||
}
|
||||
for (auto & t : tasks) {
|
||||
t.get();
|
||||
}
|
||||
}
|
||||
}
|
||||
ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
|
||||
@@ -8363,7 +8379,7 @@ int main(int argc, char ** argv) {
|
||||
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
||||
if (ggml_backend_set_n_threads_fn) {
|
||||
// TODO: better value for n_threads
|
||||
ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency());
|
||||
ggml_backend_set_n_threads_fn(backend, N_THREADS);
|
||||
}
|
||||
|
||||
size_t free, total; // NOLINT
|
||||
|
||||
@@ -0,0 +1,768 @@
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
#include "chat-parser.h"
|
||||
#include "chat-peg-parser.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "peg-parser.h"
|
||||
#include "peg-parser/testing.h"
|
||||
#include "peg-parser/simple-tokenize.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
static json create_tools();
|
||||
static void test_example_native(testing & t);
|
||||
static void test_example_qwen3_coder(testing & t);
|
||||
static void test_command7_parser_compare(testing & t);
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
testing t(std::cout);
|
||||
if (argc >= 2) {
|
||||
t.set_filter(argv[1]);
|
||||
}
|
||||
|
||||
const char * verbose = getenv("LLAMA_TEST_VERBOSE");
|
||||
if (verbose) {
|
||||
t.verbose = std::string(verbose) == "1";
|
||||
}
|
||||
|
||||
t.test("native", test_example_native);
|
||||
t.test("qwen3 coder", test_example_qwen3_coder);
|
||||
t.test("comparison", test_command7_parser_compare);
|
||||
|
||||
return t.summary();
|
||||
}
|
||||
|
||||
// Returns the fixed JSON array of tool definitions used by the tests in this
// file, in this order: get_current_weather, get_forecast, search_knowledge_base.
static json create_tools() {
    // Property schemas that appear verbatim in both weather-related tools.
    const json location_property = {
        {"type", "string"},
        {"description", "The city and state, e.g. San Francisco, CA"}
    };
    const json unit_property = {
        {"type", "string"},
        {"enum", {"celsius", "fahrenheit"}},
        {"description", "The temperature unit to use. Infer this from the users location."}
    };

    const json weather = {
        {"type", "function"},
        {"function", {
            {"name", "get_current_weather"},
            {"description", "Get the current weather in a given location"},
            {"parameters", {
                {"type", "object"},
                {"properties", {
                    {"location", location_property},
                    {"unit", unit_property}
                }},
                {"required", {"location", "unit"}},
            }},
        }}
    };

    const json forecast = {
        {"type", "function"},
        {"function", {
            {"name", "get_forecast"},
            {"description", "Get the weather forecast for a given location"},
            {"parameters", {
                {"type", "object"},
                {"properties", {
                    {"location", location_property},
                    {"unit", unit_property},
                    {"days", {
                        {"type", "integer"},
                        {"description", "Number of days to forecast (1-10)"},
                        {"minimum", 1},
                        {"maximum", 10}
                    }}
                }},
                {"required", {"location", "unit"}},
            }},
        }}
    };

    const json search = {
        {"type", "function"},
        {"function", {
            {"name", "search_knowledge_base"},
            {"description", "Search the internal technical documentation knowledge base."},
            {"parameters", {
                {"type", "object"},
                {"properties", {
                    {"query", {
                        {"type", "string"},
                        {"description", "The search query string."}
                    }},
                    {"max_results", {
                        {"type", "integer"},
                        {"description", "The maximum number of results to return."},
                        {"default", 5}
                    }},
                    {"category", {
                        {"type", "string"},
                        {"enum", {"api", "troubleshooting", "billing", "general"}},
                        {"description", "Filter search by specific category."}
                    }}
                }},
                {"required", {"query", "category"}},
                {"additionalProperties", false}
            }},
            {"strict", true}
        }}
    };

    json all_tools = json::array();
    all_tools.push_back(weather);
    all_tools.push_back(forecast);
    all_tools.push_back(search);
    return all_tools;
}
|
||||
|
||||
// Plain data holder describing a single argument of a tool.
// NOTE(review): nothing in the visible part of this file references this struct — confirm it is still needed.
struct tool_argument {
|
||||
// NOTE(review): presumably the argument's name as it appears in the tool schema — confirm
std::string name;
|
||||
// NOTE(review): presumably the JSON-schema "type" of the argument — confirm
std::string type;
|
||||
// NOTE(review): presumably true when the argument is listed under "required" — confirm
bool is_required;
|
||||
// NOTE(review): presumably the JSON schema of this argument — confirm
json schema;
|
||||
};
|
||||
|
||||
// Plain data holder describing a tool: a name, a list of tool_argument entries and a JSON value.
// NOTE(review): nothing in the visible part of this file references this struct — confirm it is still needed.
struct tool_definition {
|
||||
// NOTE(review): presumably the tool (function) name — confirm
std::string name;
|
||||
// Arguments of the tool, one tool_argument per entry.
std::vector<tool_argument> arguments;
|
||||
// NOTE(review): presumably the JSON schema of the tool's parameters — confirm
json schema;
|
||||
};
|
||||
|
||||
// Test fictitious model output that emits arguments as JSON.
|
||||
// Test fictitious model output that emits arguments as JSON.
//
// Every test_case describes the request parameters (tools, reasoning format,
// response schema, ...), the raw model output to parse, and the reasoning /
// content / tool calls expected after mapping the parse result into a
// common_chat_msg. For each case a parser is built from the parameters, its
// grammar is generated and logged, and the parsed message is compared against
// the expectations.
static void test_example_native(testing & t) {
    struct test_case {
        // Parameters
        std::string name;
        json tools;
        common_chat_tool_choice tool_choice;
        common_reasoning_format reasoning_format;
        json json_schema;
        bool parallel_tool_calls;
        bool thinking_forced_open;
        std::string input;

        // Expect
        std::string expect_reasoning;
        std::string expect_content;
        std::vector<common_chat_tool_call> expect_tool_calls;
    };

    // Builds the parser for one test case. Exactly one of three shapes is
    // returned: tool calling (non-empty tools array), response_format
    // (non-empty JSON schema object), or content-only.
    auto build_parser = [](const test_case & tc) {
        return build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
            // With reasoning_format == NONE the reasoning parser is not used at
            // all, so any thinking text ends up in the content.
            auto reasoning_in_content = (tc.reasoning_format == COMMON_REASONING_FORMAT_NONE);
            auto reasoning = p.eps();
            if (tc.thinking_forced_open) {
                // If thinking is forced open, expect a closing tag
                reasoning = p.reasoning(p.until("</think>")) + "</think>" + p.space();
            } else {
                // Otherwise, optionally accept thinking wrapped in tags
                reasoning = p.optional("<think>" + p.reasoning(p.until("</think>")) + "</think>" + p.space());
            }

            // tool calling parser
            if (tc.tools.is_array() && !tc.tools.empty()) {
                auto tools = p.choice();
                for (const auto & tool : tc.tools) {
                    const auto & function = tool.at("function");
                    std::string name = function.at("name");
                    const auto & schema = function.at("parameters");

                    auto tool_name = p.json_member("name", "\"" + p.tool_name(p.literal(name)) + "\"");
                    auto tool_args = p.json_member("arguments", p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)));

                    tools |= p.rule("tool-" + name, p.tool_open(p.literal("{")) << tool_name << "," << tool_args << "}");
                }

                // Additional comma-separated calls are only accepted when
                // parallel tool calls are enabled.
                auto parallel_calls = p.eps();
                if (tc.parallel_tool_calls) {
                    parallel_calls = p.zero_or_more("," << tools);
                }

                auto tool_call = p.trigger_rule("tool-call",
                    p.sequence({
                        p.literal("<tool_call>["),
                        tools,
                        parallel_calls,
                        p.literal("]</tool_call>")
                    })
                );

                return p.sequence({
                    (reasoning_in_content ? p.eps() : reasoning),
                    p.content(p.until("<tool_call>")),
                    p.optional(p.space() + tool_call),
                    p.space(),
                    p.end()
                });
            }

            // response_format parser
            if (tc.json_schema.is_object() && !tc.json_schema.empty()) {
                return p.sequence({
                    (reasoning_in_content ? p.eps() : reasoning),
                    p.content(p.schema(p.json(), "response-output", tc.json_schema)),
                    p.space(),
                    p.end()
                });
            }

            // Content-only parser
            return p.sequence({
                (reasoning_in_content ? p.eps() : reasoning),
                p.content(p.rest()),
                p.end()
            });
        });
    };

    std::vector<test_case> test_cases = std::vector<test_case>{
        {
            /* .name =                 */ "content with thinking_forced_open = false",
            /* .tools =                */ {},
            /* .tool_choice =          */ COMMON_CHAT_TOOL_CHOICE_NONE,
            /* .reasoning_format =     */ COMMON_REASONING_FORMAT_AUTO,
            /* .json_schema =          */ {},
            /* .parallel_tool_calls =  */ false,
            /* .thinking_forced_open = */ false,
            /* .input =                */ (
                "<think>The user said hello, I must say hello back</think>\nHello"
            ),
            /* .expect_reasoning =     */ "The user said hello, I must say hello back",
            /* .expect_content =       */ "Hello",
            /* .expect_tool_calls =    */ {},
        },
        {
            /* .name =                 */ "content with thinking_forced_open = false and no reasoning",
            /* .tools =                */ {},
            /* .tool_choice =          */ COMMON_CHAT_TOOL_CHOICE_NONE,
            /* .reasoning_format =     */ COMMON_REASONING_FORMAT_AUTO,
            /* .json_schema =          */ {},
            /* .parallel_tool_calls =  */ false,
            /* .thinking_forced_open = */ false,
            /* .input =                */ (
                "Hello"
            ),
            /* .expect_reasoning =     */ "",
            /* .expect_content =       */ "Hello",
            /* .expect_tool_calls =    */ {},
        },
        {
            /* .name =                 */ "content with thinking_forced_open = false and reasoning_format = none",
            /* .tools =                */ {},
            /* .tool_choice =          */ COMMON_CHAT_TOOL_CHOICE_NONE,
            /* .reasoning_format =     */ COMMON_REASONING_FORMAT_NONE,
            /* .json_schema =          */ {},
            /* .parallel_tool_calls =  */ false,
            // Matches the case name and the "<think>"-prefixed input. The value
            // has no effect on the result because reasoning is left in the
            // content when reasoning_format is NONE.
            /* .thinking_forced_open = */ false,
            /* .input =                */ (
                "<think>The user said hello, I must say hello back</think>\nHello"
            ),
            /* .expect_reasoning =     */ "",
            /* .expect_content =       */ "<think>The user said hello, I must say hello back</think>\nHello",
            /* .expect_tool_calls =    */ {},
        },
        {
            /* .name =                 */ "content with thinking_forced_open = true",
            /* .tools =                */ {},
            /* .tool_choice =          */ COMMON_CHAT_TOOL_CHOICE_NONE,
            /* .reasoning_format =     */ COMMON_REASONING_FORMAT_AUTO,
            /* .json_schema =          */ {},
            /* .parallel_tool_calls =  */ false,
            /* .thinking_forced_open = */ true,
            /* .input =                */ (
                "The user said hello, I must say hello back</think>\nHello"
            ),
            /* .expect_reasoning =     */ "The user said hello, I must say hello back",
            /* .expect_content =       */ "Hello",
            /* .expect_tool_calls =    */ {},
        },
        {
            /* .name =                 */ "content with thinking_forced_open = true and reasoning_format = none",
            /* .tools =                */ {},
            /* .tool_choice =          */ COMMON_CHAT_TOOL_CHOICE_NONE,
            /* .reasoning_format =     */ COMMON_REASONING_FORMAT_NONE,
            /* .json_schema =          */ {},
            /* .parallel_tool_calls =  */ false,
            /* .thinking_forced_open = */ true,
            /* .input =                */ (
                "The user said hello, I must say hello back</think>\nHello"
            ),
            /* .expect_reasoning =     */ "",
            /* .expect_content =       */ "The user said hello, I must say hello back</think>\nHello",
            /* .expect_tool_calls =    */ {},
        },
        {
            /* .name =                 */ "tools with tool_choice = auto and no parallel_tool_calls",
            /* .tools =                */ create_tools(),
            /* .tool_choice =          */ COMMON_CHAT_TOOL_CHOICE_AUTO,
            /* .reasoning_format =     */ COMMON_REASONING_FORMAT_AUTO,
            /* .json_schema =          */ {},
            /* .parallel_tool_calls =  */ false,
            /* .thinking_forced_open = */ true,
            /* .input =                */ (
                "I must get the weather in New York</think>\n"
                "<tool_call>["
                R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})"
                "]</tool_call>"
            ),
            /* .expect_reasoning =     */ "I must get the weather in New York",
            /* .expect_content =       */ "",
            /* .expect_tool_calls =    */ {{
                /* .name =      */ "get_current_weather",
                /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})",
                /* .id =        */ "",
            }},
        },
        {
            /* .name =                 */ "tools with tool_choice = auto and parallel_tool_calls",
            /* .tools =                */ create_tools(),
            /* .tool_choice =          */ COMMON_CHAT_TOOL_CHOICE_AUTO,
            /* .reasoning_format =     */ COMMON_REASONING_FORMAT_AUTO,
            /* .json_schema =          */ {},
            /* .parallel_tool_calls =  */ true,
            /* .thinking_forced_open = */ true,
            /* .input =                */ (
                "I must get the weather in New York and San Francisco and a 3 day forecast of each.</think>\nLet me search that for you."
                "<tool_call>["
                R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})"
                ", "
                R"({"name": "get_current_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}})"
                ", "
                R"({"name": "get_forecast", "arguments": {"location": "New York City, NY", "unit": "fahrenheit", "days": 3}})"
                ", "
                R"({"name": "get_forecast", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3}})"
                "]</tool_call>"
            ),
            /* .expect_reasoning =     */ "I must get the weather in New York and San Francisco and a 3 day forecast of each.",
            /* .expect_content =       */ "Let me search that for you.",
            /* .expect_tool_calls =    */ {{
                /* .name =      */ "get_current_weather",
                /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})",
                /* .id =        */ "",
            }, {
                /* .name =      */ "get_current_weather",
                /* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit"})",
                /* .id =        */ "",
            }, {
                /* .name =      */ "get_forecast",
                /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit", "days": 3})",
                /* .id =        */ "",
            }, {
                /* .name =      */ "get_forecast",
                /* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3})",
                /* .id =        */ "",
            }},
        },
        {
            /* .name =                 */ "response_format with thinking_forced_open = true",
            /* .tools =                */ {},
            /* .tool_choice =          */ COMMON_CHAT_TOOL_CHOICE_NONE,
            /* .reasoning_format =     */ COMMON_REASONING_FORMAT_AUTO,
            /* .json_schema =          */ {
                {"type", "object"},
                {"properties", {
                    {"invoice_number", {{"type", "string"}}},
                    {"amount", {{"type", "number"}}},
                    {"due_date", {{"type", "string"}}}
                }},
                {"required", {"invoice_number", "amount", "due_date"}}
            },
            /* .parallel_tool_calls =  */ false,
            /* .thinking_forced_open = */ true,
            /* .input =                */ (
                "I must produce the invoice in the requested format</think>\n"
                R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})"
            ),
            /* .expect_reasoning =     */ "I must produce the invoice in the requested format",
            /* .expect_content =       */ R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})",
            /* .expect_tool_calls =    */ {},
        },
    };

    for (const auto & tc : test_cases) {
        t.test(tc.name, [&](testing & t) {
            auto parser = build_parser(tc);
            // The grammar is built lazily whenever tools are present and a
            // tool call is not mandatory.
            auto lazy = !tc.tools.empty() && tc.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
            auto grammar = build_grammar([&](const common_grammar_builder & builder) {
                for (auto const & def : tc.tools) {
                    auto function = def.at("function");
                    auto parameters = function.at("parameters");
                    builder.resolve_refs(parameters);
                }
                parser.build_grammar(builder, lazy);
            });

            t.log("Grammar:");
            for (auto const & line : string_split(grammar, "\n")) {
                t.log(line);
            }

            common_peg_parse_context ctx(tc.input, false);
            auto result = parser.parse(ctx);

            t.assert_true("success", result.success());

            common_chat_msg msg;
            auto mapper = common_chat_peg_native_mapper(msg);
            mapper.from_ast(ctx.ast, result);

            t.assert_equal("content equal", tc.expect_content, msg.content);
            t.assert_equal("reasoning equal", tc.expect_reasoning, msg.reasoning_content);
            t.assert_equal("number of tool calls", tc.expect_tool_calls.size(), msg.tool_calls.size());
            // Only compare the calls present on both sides; a size mismatch is
            // already reported by the assertion above.
            for (auto i = 0u; i < std::min(tc.expect_tool_calls.size(), msg.tool_calls.size()); i++) {
                t.assert_equal("tool name", tc.expect_tool_calls[i].name, msg.tool_calls[i].name);
                t.assert_equal("tool args", tc.expect_tool_calls[i].arguments, msg.tool_calls[i].arguments);
            }
        });
    }
}
|
||||
|
||||
// Tests a Qwen3-Coder style tool call format, where a call is written as
//   <tool_call><function=NAME><parameter=ARG>VALUE</parameter>...</function></tool_call>
//
// A parser is built from create_tools(), its grammar is generated and logged,
// and then a sample output is parsed incrementally (token by token). Each
// partial parse must not fail, and computing the diff between consecutive
// parsed messages must not throw.
static void test_example_qwen3_coder(testing & t) {
    auto tools = create_tools();
    auto parser = build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) {
        auto content = p.rule("content", p.content(p.until("<tool_call>")));

        std::vector<common_peg_parser> tool_parsers;
        for (auto const & def : tools) {
            auto function = def.at("function");
            std::string name = function.at("name");
            auto parameters = function.at("parameters");
            auto properties = parameters.at("properties");

            // The "required" list is part of the parameters schema object (see
            // create_tools()), not of the enclosing function object.
            std::set<std::string> required_properties;
            if (parameters.contains("required")) {
                parameters.at("required").get_to(required_properties);
            }

            std::vector<common_peg_parser> arg_parsers;
            for (const auto & [param_name, param_schema] : properties.items()) {
                bool is_required = required_properties.find(param_name) != required_properties.end();
                auto type = param_schema.value("type", "object");

                // String arguments are taken verbatim up to the closing tag;
                // every other type is parsed as a JSON value.
                auto arg = p.tool_arg(p.sequence({
                    p.tool_arg_open("<parameter=" + p.tool_arg_name(p.literal(param_name)) + ">"),
                    (type == "string" ?
                        p.tool_arg_string_value(
                            p.schema(
                                p.until_one_of({
                                    "</parameter>\n<parameter=",
                                    "</parameter>\n</function>"
                                }),
                                "tool-" + name + "-arg-" + param_name + "-schema",
                                param_schema,
                                true
                            )
                        ) : p.tool_arg_json_value(
                            p.schema(
                                p.json(),
                                "tool-" + name + "-arg-" + param_name + "-schema",
                                param_schema
                            )
                        )
                    ),
                    p.tool_arg_close(
                        "</parameter>\n" +
                        p.peek(p.literal("<parameter=") | p.literal("</function>"))
                    )
                }));

                arg_parsers.push_back(is_required ?
                    p.rule("tool-" + name + "-arg-" + param_name, arg) :
                    p.optional(p.rule("tool-" + name + "-arg-" + param_name, arg)));
            }

            tool_parsers.push_back(p.rule("tool-" + name,
                p.tool_open("<function=" + p.tool_name(p.literal(name)) + ">")
                << p.sequence(arg_parsers)
                << p.tool_close(p.literal("</function>"))
            ));
        }

        auto tool_call = p.trigger_rule("tool-call",
            "<tool_call>"
            << p.choice(tool_parsers)
            << "</tool_call>"
        );

        return content + p.zero_or_more(p.space() + tool_call) + p.end();
    });

    auto grammar = build_grammar([&](const common_grammar_builder & builder) {
        for (auto const & def : tools) {
            auto function = def.at("function");
            auto parameters = function.at("parameters");
            builder.resolve_refs(parameters);
        }
        parser.build_grammar(builder);
    });

    t.log("Grammar:");
    for (auto const & line : string_split(grammar, "\n")) {
        t.log(line);
    }

    t.test("incremental parsing", [&](testing &t) {
        std::string input =
            "Let me search the knowledge base for cat pictures."
            "<tool_call>\n"
            "<function=search_knowledge_base>\n"
            "<parameter=query>cat pictures</parameter>\n"
            "<parameter=category>general</parameter>\n"
            "</function>\n"
            "</tool_call>";

        std::vector<std::string> tokens = simple_tokenize(input);

        common_chat_msg prev;
        for (auto it = tokens.begin(); it != tokens.end(); it++) {
            // Input seen so far: every token up to and including the current one.
            std::string in = std::accumulate(tokens.begin(), it + 1, std::string());

            // The second argument is true for every prefix except the full input.
            common_peg_parse_context ctx(in, it + 1 < tokens.end());

            auto result = parser.parse(ctx);
            if (!t.assert_equal("not fail", false, result.fail())) {
                t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end));
            }

            common_chat_msg msg;
            auto mapper = common_chat_peg_constructed_mapper(msg);
            mapper.from_ast(ctx.ast, result);

            //t.log("Input: " + input);
            t.log("===========================================");
            t.log("Iteration " + std::to_string(in.size()));
            t.log("Reasoning: " + msg.reasoning_content);
            t.log("Content  : " + msg.content);
            for (const auto & tc : msg.tool_calls) {
                t.log("Tool name: " + tc.name);
                t.log("Tool args: " + tc.arguments);
            }

            try {
                // This shouldn't emit any runtime errors
                auto diffs = common_chat_msg_diff::compute_diffs(prev, msg);
            } catch(const std::exception & e) {
                t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end));
                t.assert_true(std::string("failed with ") + e.what(), false);
            }

            prev = msg;
        }
    });
}
|
||||
|
||||
// Parses the same Command-R7B style output ("<|START_THINKING|>...",
// "<|START_ACTION|>[...]<|END_ACTION|>", "<|START_RESPONSE|>...") with the PEG
// parser and with the legacy common_chat_msg_parser, first as plain tests and
// then as benchmarks, both for the complete input and token by token.
//
// Declared static to match the forward declaration at the top of the file.
static void test_command7_parser_compare(testing & t) {
    auto parser = build_chat_peg_native_parser([](common_chat_peg_native_builder & p) {
        auto thinking = p.reasoning_block(
            "<|START_THINKING|>" << p.reasoning(p.until("<|END_THINKING|>")) << "<|END_THINKING|>");

        auto response = "<|START_RESPONSE|>" << p.content(p.until("<|END_RESPONSE|>")) << "<|END_RESPONSE|>";

        auto tool_call_id = p.atomic("\"tool_call_id\"" << (":" << ("\"" + p.tool_id(p.json_string_content()) + "\"")));
        auto tool_call_name = p.atomic("\"tool_name\"" << (":" << ("\"" + p.tool_name(p.json_string_content()) + "\"")));
        auto tool_call_args = "\"parameters\"" << (":" << p.tool_args(p.json()));

        // The three fields of a tool call object are accepted in any order.
        auto tool_call_fields = p.rule("tool-call-fields", tool_call_id | tool_call_name | tool_call_args);
        auto tool_call = p.rule("tool-call", p.tool(
            p.tool_open(p.literal("{"))
            << tool_call_fields
            << p.zero_or_more( p.literal(",") << tool_call_fields)
            << p.tool_close(p.literal("}"))
        ));

        auto tool_calls = p.rule("tool-calls",
            "<|START_ACTION|>"
            << ("[" << tool_call << p.zero_or_more(p.literal(",") << tool_call) << "]")
            << "<|END_ACTION|>");

        return p.optional(thinking) << (tool_calls | response) + p.end();
    });

    // Parses `input` with the PEG parser `p`, maps the result into a
    // common_chat_msg and optionally prints it to stdout.
    auto test_current = [&](const common_peg_arena & p, const std::string & input, bool is_partial, bool print_results) {
        common_peg_parse_context ctx(input, is_partial);
        auto result = p.parse(ctx);

        common_chat_msg msg;
        auto mapper = common_chat_peg_native_mapper(msg);
        mapper.from_ast(ctx.ast, result);

        if (print_results) {
            std::cout << "== Parsed (new) ==\n";
            std::cout << "=== Reasoning ===\n";
            std::cout << msg.reasoning_content << "\n";
            std::cout << "\n\n=== Content ===\n";
            std::cout << msg.content << "\n";
            std::cout << "\n\n=== Tool Calls ===\n";
            for (const auto & tc : msg.tool_calls) {
                std::cout << "id: " << tc.id << "\n";
                std::cout << "name: " << tc.name << "\n";
                std::cout << "args: " << tc.arguments << "\n";
            }
        }
    };

    // Parses `input` with the legacy parser. Throws
    // common_chat_msg_partial_exception when the tool call or the response is
    // incomplete.
    auto test_legacy = [&](const std::string & input, bool need_more_input, bool print_results) {
        // Original common_chat_combinator_parser taken from chat.cpp
        common_chat_msg_parser builder(
            input,
            /* .is_partial = */ need_more_input,
            {
                /* .format = */ COMMON_CHAT_FORMAT_GENERIC,
                /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
                /* .reasoning_in_content = */ false,
                /* .thinking_forced_open = */ false,
            }
        );

        builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");

        static const common_regex start_action_regex("<\\|START_ACTION\\|>");
        static const common_regex end_action_regex("<\\|END_ACTION\\|>");
        static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
        static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");

        if (builder.try_find_regex(start_action_regex)) {
            // If we didn't extract thoughts, prelude includes them.
            auto tool_calls = builder.consume_json_with_dumped_args({ { "parameters" } });
            for (const auto & tool_call : tool_calls.value) {
                std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
                std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : "";
                std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : "";
                if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) {
                    throw common_chat_msg_partial_exception("incomplete tool call");
                }
            }
            if (tool_calls.is_partial) {
                throw common_chat_msg_partial_exception("incomplete tool call");
            }
            builder.consume_regex(end_action_regex);
        } else if (builder.try_find_regex(start_response_regex)) {
            if (!builder.try_find_regex(end_response_regex)) {
                builder.add_content(builder.consume_rest());
                throw common_chat_msg_partial_exception(end_response_regex.str());
            }
        } else {
            builder.add_content(builder.consume_rest());
        }

        if (print_results) {
            std::cout << "== Parsed (legacy) ==\n";
            std::cout << "=== Reasoning ===\n";
            std::cout << builder.result().reasoning_content << "\n";
            std::cout << "\n\n=== Content ===\n";
            std::cout << builder.result().content << "\n";
            std::cout << "\n\n=== Tool Calls ===\n";
            for (const auto & tc : builder.result().tool_calls) {
                std::cout << "id: " << tc.id << "\n";
                std::cout << "name: " << tc.name << "\n";
                std::cout << "args: " << tc.arguments << "\n";
            }
        }
    };

    std::string reasoning = "To plan an effective trip to Japan that includes both historical sites and modern attractions within a "
            "budget of $4000 for a two-week stay, we need to:\n\n"
            "1. Identify key historical sites and modern attractions in Japan.\n"
            "2. Find affordable accommodation options that provide a balance between comfort and cost.\n"
            "3. Determine the best modes of transportation for getting around Japan.\n"
            "4. Create a day-by-day itinerary that ensures the user gets to see a variety of attractions without "
            "overspending.\n"
            "5. Provide a detailed cost breakdown that includes accommodation, transportation, meals, and entry fees "
            "to attractions.";

    // Each entry is (tool_call_id, tool_name, parameters).
    std::vector<std::tuple<std::string, std::string, nlohmann::json>> tool_calls = {{
        "call_0",
        "plan_trip",
        nlohmann::json::parse(R"({
            "destination": "Japan",
            "duration": 14,
            "budget": 4000,
            "interests": ["historical sites", "modern attractions"],
            "accommodation_preferences": "affordable",
            "transportation_preferences": "efficient",
            "meal_preferences": "local cuisine"
        })")
    }};

    std::vector<std::string> tokens;

    // Build tokens
    if (!reasoning.empty()) {
        auto tokenized = simple_tokenize(reasoning);
        tokens.emplace_back("<|START_THINKING|>");
        tokens.insert(tokens.end(), tokenized.begin(), tokenized.end());
        tokens.emplace_back("<|END_THINKING|>");
    }

    if (!tool_calls.empty()) {
        tokens.emplace_back("<|START_ACTION|>");

        auto json = nlohmann::json::array();
        for (const auto & tc : tool_calls) {
            auto tc_json = nlohmann::json::object();
            tc_json["tool_call_id"] = std::get<0>(tc);
            tc_json["tool_name"] = std::get<1>(tc);
            tc_json["parameters"] = std::get<2>(tc);
            json.push_back(tc_json);
        }

        auto tokenized = simple_tokenize(json.dump(-1, ' ', true));
        tokens.insert(tokens.end(), tokenized.begin(), tokenized.end());

        tokens.emplace_back("<|END_ACTION|>");
    }

    std::string input = std::accumulate(tokens.begin(), tokens.end(), std::string());

    // Run tests
    t.test("legacy_parse", [&](testing & /* t */) {
        test_legacy(input, false, false);
    });

    t.test("current_parse", [&](testing & /* t */) {
        test_current(parser, input, false, false);
    });

    // Run benchmarks
    t.bench("legacy_parse_benchmark complete", [&]() {
        test_legacy(input, false, false);
    });

    t.bench("legacy_parse_benchmark incremental", [&]() {
        std::string in;
        for (auto i = 0u; i < tokens.size(); i++) {
            in += tokens[i];

            try {
                test_legacy(in, i + 1 < tokens.size(), false);
            } catch (common_chat_msg_partial_exception & /* e */) {
                // Do nothing, this is expected
            }
        }
    }, 20);

    t.bench("current_parse_benchmark complete", [&]() {
        test_current(parser, input, false, false);
    }, 100);

    t.bench("current_parse_benchmark incremental", [&]() {
        std::string in;
        for (auto i = 0u; i < tokens.size(); i++) {
            in += tokens[i];
            test_current(parser, in, i + 1 < tokens.size(), false);
        }
    }, 20);
}
|
||||
@@ -0,0 +1,25 @@
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
|
||||
#include "peg-parser/tests.h"
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
testing t(std::cout);
|
||||
if (argc >= 2) {
|
||||
t.set_filter(argv[1]);
|
||||
}
|
||||
|
||||
const char * verbose = getenv("LLAMA_TEST_VERBOSE");
|
||||
if (verbose) {
|
||||
t.verbose = std::string(verbose) == "1";
|
||||
}
|
||||
|
||||
t.test("basic", test_basic);
|
||||
t.test("unicode", test_unicode);
|
||||
t.test("json", test_json_parser);
|
||||
t.test("gbnf", test_gbnf_generation);
|
||||
t.test("serialization", test_json_serialization);
|
||||
|
||||
return t.summary();
|
||||
}
|
||||
@@ -791,7 +791,7 @@ static void handle_media(
|
||||
SRV_INF("downloading image from '%s'\n", url.c_str());
|
||||
auto res = common_remote_get_content(url, params);
|
||||
if (200 <= res.first && res.first < 300) {
|
||||
SRV_INF("downloaded %ld bytes\n", res.second.size());
|
||||
SRV_INF("downloaded %zu bytes\n", res.second.size());
|
||||
raw_buffer data;
|
||||
data.insert(data.end(), res.second.begin(), res.second.end());
|
||||
out_files.push_back(data);
|
||||
@@ -1045,6 +1045,9 @@ json oaicompat_chat_params_parse(
|
||||
for (const auto & stop : chat_params.additional_stops) {
|
||||
llama_params["stop"].push_back(stop);
|
||||
}
|
||||
if (!chat_params.parser.empty()) {
|
||||
llama_params["chat_parser"] = chat_params.parser;
|
||||
}
|
||||
|
||||
// Handle "n" field
|
||||
int n_choices = json_value(body, "n", 1);
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <sheredom/subprocess.h>
|
||||
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
@@ -889,6 +890,28 @@ struct pipe_t {
|
||||
}
|
||||
};
|
||||
|
||||
// Returns a copy of `value` in which every character has been passed through
// std::tolower; the input string is left untouched.
static std::string to_lower_copy(const std::string & value) {
    std::string result;
    result.reserve(value.size());
    // Iterate as unsigned char so std::tolower never receives a negative value.
    for (const unsigned char ch : value) {
        result.push_back(static_cast<char>(std::tolower(ch)));
    }
    return result;
}
|
||||
|
||||
// Decides whether a response header must be dropped instead of being copied
// from the child response. Matching is case-sensitive against lowercase names.
//
// Returns true for "server", "transfer-encoding", "keep-alive" and for any
// name that starts with "access-control-"; false otherwise.
static bool should_strip_proxy_header(const std::string & header_name) {
    // Headers that get duplicated when router forwards child responses
    const char * const duplicated_headers[] = { "server", "transfer-encoding", "keep-alive" };
    for (const char * candidate : duplicated_headers) {
        if (header_name == candidate) {
            return true;
        }
    }

    // Router injects CORS, child also sends them: duplicate
    const std::string cors_prefix = "access-control-";
    return header_name.compare(0, cors_prefix.size(), cors_prefix) == 0;
}
|
||||
|
||||
server_http_proxy::server_http_proxy(
|
||||
const std::string & method,
|
||||
const std::string & host,
|
||||
@@ -925,6 +948,14 @@ server_http_proxy::server_http_proxy(
|
||||
msg_t msg;
|
||||
msg.status = response.status;
|
||||
for (const auto & [key, value] : response.headers) {
|
||||
const auto lowered = to_lower_copy(key);
|
||||
if (should_strip_proxy_header(lowered)) {
|
||||
continue;
|
||||
}
|
||||
if (lowered == "content-type") {
|
||||
msg.content_type = value;
|
||||
continue;
|
||||
}
|
||||
msg.headers[key] = value;
|
||||
}
|
||||
return pipe->write(std::move(msg)); // send headers first
|
||||
@@ -932,7 +963,7 @@ server_http_proxy::server_http_proxy(
|
||||
httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
|
||||
// send data chunks
|
||||
// returns false if pipe is closed / broken (signal to stop receiving)
|
||||
return pipe->write({{}, 0, std::string(data, data_length)});
|
||||
return pipe->write({{}, 0, std::string(data, data_length), ""});
|
||||
};
|
||||
|
||||
// prepare the request to destination server
|
||||
@@ -955,8 +986,8 @@ server_http_proxy::server_http_proxy(
|
||||
if (result.error() != httplib::Error::Success) {
|
||||
auto err_str = httplib::to_string(result.error());
|
||||
SRV_ERR("http client error: %s\n", err_str.c_str());
|
||||
pipe->write({{}, 500, ""}); // header
|
||||
pipe->write({{}, 0, "proxy error: " + err_str}); // body
|
||||
pipe->write({{}, 500, "", ""}); // header
|
||||
pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body
|
||||
}
|
||||
pipe->close_write(); // signal EOF to reader
|
||||
SRV_DBG("%s", "client request thread ended\n");
|
||||
@@ -964,12 +995,17 @@ server_http_proxy::server_http_proxy(
|
||||
this->thread.detach();
|
||||
|
||||
// wait for the first chunk (headers)
|
||||
msg_t header;
|
||||
if (pipe->read(header, should_stop)) {
|
||||
SRV_DBG("%s", "received response headers\n");
|
||||
this->status = header.status;
|
||||
this->headers = header.headers;
|
||||
} else {
|
||||
SRV_DBG("%s", "no response headers received (request cancelled?)\n");
|
||||
{
|
||||
msg_t header;
|
||||
if (pipe->read(header, should_stop)) {
|
||||
SRV_DBG("%s", "received response headers\n");
|
||||
this->status = header.status;
|
||||
this->headers = std::move(header.headers);
|
||||
if (!header.content_type.empty()) {
|
||||
this->content_type = std::move(header.content_type);
|
||||
}
|
||||
} else {
|
||||
SRV_DBG("%s", "no response headers received (request cancelled?)\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,5 +170,6 @@ private:
|
||||
std::map<std::string, std::string> headers;
|
||||
int status = 0;
|
||||
std::string data;
|
||||
std::string content_type;
|
||||
};
|
||||
};
|
||||
|
||||
@@ -297,6 +297,9 @@ task_params server_task::params_from_json_cmpl(
|
||||
params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
|
||||
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
|
||||
params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
|
||||
if (data.contains("chat_parser")) {
|
||||
params.oaicompat_chat_syntax.parser.load(data.at("chat_parser").get<std::string>());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
Vendored
+3
@@ -144,4 +144,7 @@ if (CPPHTTPLIB_OPENSSL_SUPPORT)
|
||||
find_library(SECURITY_FRAMEWORK Security REQUIRED)
|
||||
target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK})
|
||||
endif()
|
||||
if (WIN32 AND NOT MSVC)
|
||||
target_link_libraries(${TARGET} PUBLIC crypt32)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
Reference in New Issue
Block a user