feat: implement Ollama HTTP client for Model Gateway (issue #39)
Add async HTTP client wrapping the Ollama REST API with: - OllamaClient with generate, generate_stream, chat, embed, list_models, is_healthy - NDJSON streaming parser for /api/generate streaming responses - Serde types for all Ollama API endpoints - OllamaError enum with Http, Api, Deserialization, StreamIncomplete variants - OllamaClientConfig for timeout and connection pool settings - Integration into ModelGatewayServiceImpl (constructor now returns Result) - 48 tests (types serde, wiremock HTTP mocks, error handling, config) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
290
Cargo.lock
generated
290
Cargo.lock
generated
@@ -230,6 +230,16 @@ dependencies = [
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "assert-json-diff"
|
||||
version = "2.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.89"
|
||||
@@ -490,6 +500,26 @@ dependencies = [
|
||||
"tiny-keccak",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation-sys"
|
||||
version = "0.8.7"
|
||||
@@ -530,6 +560,24 @@ dependencies = [
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deadpool"
|
||||
version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b"
|
||||
dependencies = [
|
||||
"deadpool-runtime",
|
||||
"lazy_static",
|
||||
"num_cpus",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deadpool-runtime"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b"
|
||||
|
||||
[[package]]
|
||||
name = "derive_arbitrary"
|
||||
version = "1.4.2"
|
||||
@@ -585,6 +633,15 @@ version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
||||
|
||||
[[package]]
|
||||
name = "encoding_rs"
|
||||
version = "0.8.35"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.2"
|
||||
@@ -665,6 +722,21 @@ version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
|
||||
|
||||
[[package]]
|
||||
name = "foreign-types"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
|
||||
dependencies = [
|
||||
"foreign-types-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "foreign-types-shared"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
|
||||
|
||||
[[package]]
|
||||
name = "form_urlencoded"
|
||||
version = "1.2.2"
|
||||
@@ -888,6 +960,12 @@ version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "1.4.0"
|
||||
@@ -986,6 +1064,22 @@ dependencies = [
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-tls"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"native-tls",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.20"
|
||||
@@ -1004,9 +1098,11 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"socket2 0.6.3",
|
||||
"system-configuration",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
"windows-registry",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1455,10 +1551,14 @@ name = "model-gateway"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bytes",
|
||||
"futures",
|
||||
"llm-multiverse-proto",
|
||||
"prost",
|
||||
"prost-types",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tempfile",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
@@ -1467,6 +1567,7 @@ dependencies = [
|
||||
"tonic",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1475,6 +1576,23 @@ version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084"
|
||||
|
||||
[[package]]
|
||||
name = "native-tls"
|
||||
version = "0.2.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"log",
|
||||
"openssl",
|
||||
"openssl-probe",
|
||||
"openssl-sys",
|
||||
"schannel",
|
||||
"security-framework",
|
||||
"security-framework-sys",
|
||||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nu-ansi-term"
|
||||
version = "0.50.3"
|
||||
@@ -1558,12 +1676,66 @@ dependencies = [
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
||||
|
||||
[[package]]
|
||||
name = "openssl"
|
||||
version = "0.10.75"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"cfg-if",
|
||||
"foreign-types",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"openssl-macros",
|
||||
"openssl-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openssl-macros"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openssl-probe"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-sys"
|
||||
version = "0.9.111"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
"pkg-config",
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.5"
|
||||
@@ -1970,17 +2142,22 @@ checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"bytes",
|
||||
"encoding_rs",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2",
|
||||
"http",
|
||||
"http-body",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-rustls",
|
||||
"hyper-tls",
|
||||
"hyper-util",
|
||||
"js-sys",
|
||||
"log",
|
||||
"mime",
|
||||
"native-tls",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"quinn",
|
||||
@@ -1991,13 +2168,16 @@ dependencies = [
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-rustls",
|
||||
"tokio-util",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tower-service",
|
||||
"url",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"wasm-streams",
|
||||
"web-sys",
|
||||
"webpki-roots",
|
||||
]
|
||||
@@ -2127,6 +2307,15 @@ version = "1.0.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f"
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.2.0"
|
||||
@@ -2161,6 +2350,29 @@ dependencies = [
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "3.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"core-foundation 0.10.1",
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
"security-framework-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework-sys"
|
||||
version = "2.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "1.0.27"
|
||||
@@ -2405,6 +2617,27 @@ dependencies = [
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system-configuration"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"core-foundation 0.9.4",
|
||||
"system-configuration-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system-configuration-sys"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tap"
|
||||
version = "1.0.1"
|
||||
@@ -2526,6 +2759,16 @@ dependencies = [
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-native-tls"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
|
||||
dependencies = [
|
||||
"native-tls",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.26.4"
|
||||
@@ -2980,6 +3223,19 @@ dependencies = [
|
||||
"wasmparser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-streams"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasmparser"
|
||||
version = "0.244.0"
|
||||
@@ -3062,6 +3318,17 @@ version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
|
||||
|
||||
[[package]]
|
||||
name = "windows-registry"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
"windows-result",
|
||||
"windows-strings",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-result"
|
||||
version = "0.4.1"
|
||||
@@ -3171,6 +3438,29 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wiremock"
|
||||
version = "0.6.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031"
|
||||
dependencies = [
|
||||
"assert-json-diff",
|
||||
"base64",
|
||||
"deadpool",
|
||||
"futures",
|
||||
"http",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"log",
|
||||
"once_cell",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen"
|
||||
version = "0.51.0"
|
||||
|
||||
@@ -42,6 +42,7 @@
|
||||
| #36 | Implement GetCorrelated gRPC endpoint | Phase 4 | `COMPLETED` | Rust | [issue-036.md](issue-036.md) |
|
||||
| #37 | Integration tests for Memory Service | Phase 4 | `COMPLETED` | Rust | [issue-037.md](issue-037.md) |
|
||||
| #38 | Scaffold Model Gateway Rust project | Phase 5 | `COMPLETED` | Rust | [issue-038.md](issue-038.md) |
|
||||
| #39 | Implement Ollama HTTP client | Phase 5 | `IMPLEMENTING` | Rust | [issue-039.md](issue-039.md) |
|
||||
|
||||
## Status Legend
|
||||
|
||||
@@ -83,6 +84,8 @@
|
||||
|
||||
### Model Gateway
|
||||
- [issue-012.md](issue-012.md) — model_gateway.proto (ModelGatewayService)
|
||||
- [issue-038.md](issue-038.md) — Scaffold Model Gateway Rust project
|
||||
- [issue-039.md](issue-039.md) — Ollama HTTP client (reqwest, streaming, embeddings)
|
||||
|
||||
### Search Service
|
||||
- [issue-013.md](issue-013.md) — search.proto (SearchService)
|
||||
|
||||
631
implementation-plans/issue-039.md
Normal file
631
implementation-plans/issue-039.md
Normal file
@@ -0,0 +1,631 @@
|
||||
# Implementation Plan — Issue #39: Implement Ollama HTTP client
|
||||
|
||||
## Metadata
|
||||
|
||||
| Field | Value |
|
||||
|---|---|
|
||||
| Issue | [#39](https://git.shahondin1624.de/llm-multiverse/llm-multiverse/issues/39) |
|
||||
| Title | Implement Ollama HTTP client |
|
||||
| Milestone | Phase 5: Model Gateway |
|
||||
| Labels | |
|
||||
| Status | `IMPLEMENTING` |
|
||||
| Language | Rust |
|
||||
| Related Plans | [issue-012.md](issue-012.md), [issue-038.md](issue-038.md) |
|
||||
| Blocked by | #38 (completed) |
|
||||
|
||||
## Acceptance Criteria
|
||||
|
||||
- [ ] Async HTTP client using reqwest
|
||||
- [ ] Support for /api/generate (streaming and non-streaming)
|
||||
- [ ] Support for /api/chat with message history
|
||||
- [ ] Support for /api/embed (embeddings)
|
||||
- [ ] Connection pooling and timeout configuration
|
||||
- [ ] Error handling for Ollama-specific error responses
|
||||
|
||||
## Architecture Analysis
|
||||
|
||||
### Service Context
|
||||
|
||||
This issue belongs to the **Model Gateway** service (`services/model-gateway/`). The Ollama HTTP client is the core backend layer that the gRPC service handlers (`service.rs`) will call to fulfil `Inference`, `StreamInference`, `GenerateEmbedding`, and `IsModelReady` RPCs.
|
||||
|
||||
The client wraps the Ollama REST API (default `http://localhost:11434`) and exposes typed Rust methods that the gRPC handlers can call directly. The gRPC handlers (issue #40+) will translate proto request/response types to/from the Ollama client types defined here.
|
||||
|
||||
**gRPC endpoints affected (consumers of this client):**
|
||||
- `Inference` — will call `OllamaClient::generate()` (non-streaming)
|
||||
- `StreamInference` — will call `OllamaClient::generate_stream()` (streaming)
|
||||
- `GenerateEmbedding` — will call `OllamaClient::embed()`
|
||||
- `IsModelReady` — will call `OllamaClient::list_models()` to check actual Ollama availability
|
||||
|
||||
**Proto messages involved:**
|
||||
- `InferenceParams` — carries prompt, model routing hints (TaskComplexity), temperature, top_p, max_tokens, stop_sequences
|
||||
- `InferenceResponse` — text, finish_reason, tokens_used
|
||||
- `StreamInferenceResponse` — token, finish_reason
|
||||
- `GenerateEmbeddingRequest` — text, model
|
||||
- `GenerateEmbeddingResponse` — embedding vector, dimensions
|
||||
|
||||
### Existing Patterns
|
||||
|
||||
- **Config:** `services/model-gateway/src/config.rs` already defines `Config` with `ollama_url: String` (default `http://localhost:11434`) and `ModelRoutingConfig` for model name resolution.
|
||||
- **Service struct:** `services/model-gateway/src/service.rs` defines `ModelGatewayServiceImpl` holding `Config`. The `OllamaClient` will be added here as a field.
|
||||
- **Error types:** Other services use `thiserror` for module-level error enums (e.g., `DbError`, `EmbeddingError`, `ProvenanceError`). The model-gateway `Cargo.toml` already includes `thiserror = "2"`.
|
||||
- **Async runtime:** `tokio` with `features = ["full"]` is already a dependency. `tokio-stream = "0.1"` is also present.
|
||||
- **Serde:** `serde = { version = "1", features = ["derive"] }` is already a dependency for config deserialization.
|
||||
|
||||
### Dependencies
|
||||
|
||||
- **reqwest** (new) — HTTP client with connection pooling, async support, JSON serialization, and streaming response bodies. Features needed: `json`, `stream`.
|
||||
- **futures** (new) — For `Stream` trait and stream combinators (`futures::Stream`, `futures::StreamExt`). Needed to expose streaming generate responses as a `Stream` type.
|
||||
- **serde_json** (new) — For parsing newline-delimited JSON (NDJSON) from Ollama streaming responses. While `reqwest` can deserialize full JSON responses, streaming requires manual line-by-line parsing.
|
||||
- **No proto changes required** — the Ollama client is an internal HTTP layer; the proto definitions are already complete from issue #12.
|
||||
|
||||
## Implementation Steps
|
||||
|
||||
### 1. Types & Configuration
|
||||
|
||||
**Add Ollama-specific configuration to `services/model-gateway/src/config.rs`:**
|
||||
|
||||
```rust
|
||||
/// Configuration for the Ollama HTTP client.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct OllamaClientConfig {
|
||||
/// Request timeout in seconds (default: 300 — generous for large model inference).
|
||||
#[serde(default = "default_request_timeout_secs")]
|
||||
pub request_timeout_secs: u64,
|
||||
|
||||
/// Connection timeout in seconds (default: 10).
|
||||
#[serde(default = "default_connect_timeout_secs")]
|
||||
pub connect_timeout_secs: u64,
|
||||
|
||||
/// Maximum idle connections in the pool (default: 10).
|
||||
#[serde(default = "default_pool_max_idle")]
|
||||
pub pool_max_idle: usize,
|
||||
|
||||
/// Idle connection timeout in seconds (default: 60).
|
||||
#[serde(default = "default_pool_idle_timeout_secs")]
|
||||
pub pool_idle_timeout_secs: u64,
|
||||
}
|
||||
```
|
||||
|
||||
Add `#[serde(default)] pub client: OllamaClientConfig` field to the existing `Config` struct.
|
||||
|
||||
**Define Ollama API request/response types in `services/model-gateway/src/ollama/types.rs`:**
|
||||
|
||||
These are serde structs matching the Ollama REST API JSON schema.
|
||||
|
||||
```rust
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// --- /api/generate ---
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct GenerateRequest {
|
||||
pub model: String,
|
||||
pub prompt: String,
|
||||
pub stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub options: Option<GenerateOptions>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct GenerateOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub num_predict: Option<i32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Full response from /api/generate with stream:false.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GenerateResponse {
|
||||
pub model: String,
|
||||
pub response: String,
|
||||
pub done: bool,
|
||||
#[serde(default)]
|
||||
pub done_reason: Option<String>,
|
||||
#[serde(default)]
|
||||
pub total_duration: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub eval_count: Option<u32>,
|
||||
#[serde(default)]
|
||||
pub prompt_eval_count: Option<u32>,
|
||||
}
|
||||
|
||||
/// Single chunk from /api/generate with stream:true (NDJSON).
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GenerateStreamChunk {
|
||||
pub model: String,
|
||||
pub response: String,
|
||||
pub done: bool,
|
||||
#[serde(default)]
|
||||
pub done_reason: Option<String>,
|
||||
#[serde(default)]
|
||||
pub eval_count: Option<u32>,
|
||||
#[serde(default)]
|
||||
pub prompt_eval_count: Option<u32>,
|
||||
}
|
||||
|
||||
// --- /api/chat ---
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ChatRole {
|
||||
System,
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: ChatRole,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ChatRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub options: Option<GenerateOptions>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ChatResponse {
|
||||
pub model: String,
|
||||
pub message: ChatMessage,
|
||||
pub done: bool,
|
||||
#[serde(default)]
|
||||
pub done_reason: Option<String>,
|
||||
#[serde(default)]
|
||||
pub total_duration: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub eval_count: Option<u32>,
|
||||
#[serde(default)]
|
||||
pub prompt_eval_count: Option<u32>,
|
||||
}
|
||||
|
||||
// --- /api/embed ---
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct EmbedRequest {
|
||||
pub model: String,
|
||||
pub input: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct EmbedResponse {
|
||||
pub model: String,
|
||||
pub embeddings: Vec<Vec<f32>>,
|
||||
}
|
||||
|
||||
// --- /api/tags (list models) ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ListModelsResponse {
|
||||
pub models: Vec<ModelInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ModelInfo {
|
||||
pub name: String,
|
||||
pub model: String,
|
||||
#[serde(default)]
|
||||
pub size: u64,
|
||||
#[serde(default)]
|
||||
pub digest: Option<String>,
|
||||
}
|
||||
|
||||
// --- /api/show (model details) ---
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ShowModelRequest {
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ShowModelResponse {
|
||||
pub modelfile: Option<String>,
|
||||
pub parameters: Option<String>,
|
||||
pub template: Option<String>,
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Core Logic
|
||||
|
||||
**Create `services/model-gateway/src/ollama/error.rs` — Error types:**
|
||||
|
||||
```rust
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OllamaError {
|
||||
/// HTTP-level error (connection refused, timeout, DNS, TLS, etc.).
|
||||
#[error("HTTP error: {0}")]
|
||||
Http(#[from] reqwest::Error),
|
||||
|
||||
/// Ollama returned a non-2xx status code.
|
||||
#[error("Ollama API error (status {status}): {message}")]
|
||||
Api {
|
||||
status: u16,
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Failed to deserialize Ollama JSON response.
|
||||
#[error("deserialization error: {0}")]
|
||||
Deserialization(String),
|
||||
|
||||
/// Stream terminated unexpectedly without a done:true chunk.
|
||||
#[error("stream ended unexpectedly")]
|
||||
StreamIncomplete,
|
||||
}
|
||||
```
|
||||
|
||||
**Create `services/model-gateway/src/ollama/client.rs` — OllamaClient:**
|
||||
|
||||
```rust
|
||||
use std::time::Duration;
|
||||
use futures::Stream;
|
||||
use reqwest::Client;
|
||||
|
||||
use crate::config::{Config, OllamaClientConfig};
|
||||
use super::error::OllamaError;
|
||||
use super::types::*;
|
||||
|
||||
/// Async HTTP client for the Ollama REST API.
|
||||
///
|
||||
/// Wraps `reqwest::Client` with connection pooling, timeouts, and
|
||||
/// typed methods for each Ollama endpoint.
|
||||
pub struct OllamaClient {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
impl OllamaClient {
|
||||
/// Create a new client from the service configuration.
|
||||
///
|
||||
/// Configures connection pooling, timeouts, and the base URL
|
||||
/// from `Config.ollama_url` and `Config.client`.
|
||||
pub fn new(config: &Config) -> Result<Self, OllamaError> {
|
||||
let client_config = &config.client;
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(client_config.request_timeout_secs))
|
||||
.connect_timeout(Duration::from_secs(client_config.connect_timeout_secs))
|
||||
.pool_max_idle_per_host(client_config.pool_max_idle)
|
||||
.pool_idle_timeout(Duration::from_secs(client_config.pool_idle_timeout_secs))
|
||||
.build()?;
|
||||
|
||||
let base_url = config.ollama_url.trim_end_matches('/').to_string();
|
||||
|
||||
Ok(Self { client, base_url })
|
||||
}
|
||||
|
||||
/// POST /api/generate (non-streaming).
|
||||
///
|
||||
/// Sends a prompt to the specified model and returns the complete response.
|
||||
pub async fn generate(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt: &str,
|
||||
options: Option<GenerateOptions>,
|
||||
) -> Result<GenerateResponse, OllamaError> {
|
||||
let request = GenerateRequest {
|
||||
model: model.to_string(),
|
||||
prompt: prompt.to_string(),
|
||||
stream: false,
|
||||
options,
|
||||
};
|
||||
|
||||
let resp = self.client
|
||||
.post(format!("{}/api/generate", self.base_url))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
self.handle_error_response(resp)
|
||||
.await?
|
||||
.json::<GenerateResponse>()
|
||||
.await
|
||||
.map_err(|e| OllamaError::Deserialization(e.to_string()))
|
||||
}
|
||||
|
||||
/// POST /api/generate (streaming).
|
||||
///
|
||||
/// Returns a `Stream` of `GenerateStreamChunk` items. Each chunk
|
||||
/// contains a partial token. The final chunk has `done: true`.
|
||||
///
|
||||
/// Ollama streams NDJSON (one JSON object per line). This method
|
||||
/// reads the response body as a byte stream, splits on newlines,
|
||||
/// and deserializes each line.
|
||||
pub async fn generate_stream(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt: &str,
|
||||
options: Option<GenerateOptions>,
|
||||
) -> Result<
|
||||
impl Stream<Item = Result<GenerateStreamChunk, OllamaError>>,
|
||||
OllamaError,
|
||||
> {
|
||||
let request = GenerateRequest {
|
||||
model: model.to_string(),
|
||||
prompt: prompt.to_string(),
|
||||
stream: true,
|
||||
options,
|
||||
};
|
||||
|
||||
let resp = self.client
|
||||
.post(format!("{}/api/generate", self.base_url))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let resp = self.handle_error_response(resp).await?;
|
||||
Ok(Self::ndjson_stream::<GenerateStreamChunk>(resp))
|
||||
}
|
||||
|
||||
/// POST /api/chat (non-streaming).
|
||||
///
|
||||
/// Sends a chat conversation (message history) to the model.
|
||||
pub async fn chat(
|
||||
&self,
|
||||
model: &str,
|
||||
messages: Vec<ChatMessage>,
|
||||
options: Option<GenerateOptions>,
|
||||
) -> Result<ChatResponse, OllamaError> {
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
stream: false,
|
||||
options,
|
||||
};
|
||||
|
||||
let resp = self.client
|
||||
.post(format!("{}/api/chat", self.base_url))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
self.handle_error_response(resp)
|
||||
.await?
|
||||
.json::<ChatResponse>()
|
||||
.await
|
||||
.map_err(|e| OllamaError::Deserialization(e.to_string()))
|
||||
}
|
||||
|
||||
/// POST /api/embed.
|
||||
///
|
||||
/// Generates embedding vectors for the given input texts.
|
||||
/// Returns one embedding vector per input string.
|
||||
pub async fn embed(
|
||||
&self,
|
||||
model: &str,
|
||||
input: Vec<String>,
|
||||
) -> Result<EmbedResponse, OllamaError> {
|
||||
let request = EmbedRequest {
|
||||
model: model.to_string(),
|
||||
input,
|
||||
};
|
||||
|
||||
let resp = self.client
|
||||
.post(format!("{}/api/embed", self.base_url))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
self.handle_error_response(resp)
|
||||
.await?
|
||||
.json::<EmbedResponse>()
|
||||
.await
|
||||
.map_err(|e| OllamaError::Deserialization(e.to_string()))
|
||||
}
|
||||
|
||||
/// GET /api/tags.
|
||||
///
|
||||
/// Lists all models available on the Ollama instance.
|
||||
pub async fn list_models(&self) -> Result<ListModelsResponse, OllamaError> {
|
||||
let resp = self.client
|
||||
.get(format!("{}/api/tags", self.base_url))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
self.handle_error_response(resp)
|
||||
.await?
|
||||
.json::<ListModelsResponse>()
|
||||
.await
|
||||
.map_err(|e| OllamaError::Deserialization(e.to_string()))
|
||||
}
|
||||
|
||||
/// Check if Ollama is reachable by hitting GET /api/tags.
|
||||
/// Returns true if the request succeeds, false otherwise.
|
||||
pub async fn is_healthy(&self) -> bool {
|
||||
self.list_models().await.is_ok()
|
||||
}
|
||||
|
||||
/// Parse NDJSON streaming response into a Stream of typed chunks.
|
||||
///
|
||||
/// Ollama streams responses as newline-delimited JSON. Each line
|
||||
/// is a complete JSON object. This method uses `bytes_stream()`
|
||||
/// from reqwest and buffers bytes until a newline is found,
|
||||
/// then deserializes each complete line.
|
||||
fn ndjson_stream<T: serde::de::DeserializeOwned>(
|
||||
resp: reqwest::Response,
|
||||
) -> impl Stream<Item = Result<T, OllamaError>> {
|
||||
use futures::StreamExt;
|
||||
|
||||
let byte_stream = resp.bytes_stream();
|
||||
let mut buffer = Vec::new();
|
||||
|
||||
futures::stream::unfold(
|
||||
(byte_stream, buffer),
|
||||
|(mut stream, mut buf)| async move {
|
||||
// Implementation: accumulate bytes, split on \n,
|
||||
// deserialize each complete line as T.
|
||||
// Return None when stream ends.
|
||||
// ...
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Check response status and extract error message for non-2xx responses.
|
||||
async fn handle_error_response(
|
||||
&self,
|
||||
resp: reqwest::Response,
|
||||
) -> Result<reqwest::Response, OllamaError> {
|
||||
if resp.status().is_success() {
|
||||
return Ok(resp);
|
||||
}
|
||||
|
||||
let status = resp.status().as_u16();
|
||||
let message = resp
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "unknown error".to_string());
|
||||
|
||||
Err(OllamaError::Api { status, message })
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**NDJSON stream implementation detail:**
|
||||
|
||||
The `ndjson_stream` method will use `reqwest::Response::bytes_stream()` (requires the `stream` feature) and `futures::stream::unfold` to:
|
||||
1. Accumulate bytes from the HTTP response body into a buffer.
|
||||
2. On each newline boundary, extract the complete line.
|
||||
3. Deserialize the line as `T` using `serde_json::from_slice`.
|
||||
4. Yield `Ok(T)` or `Err(OllamaError::Deserialization(...))`.
|
||||
5. Return `None` when the byte stream is exhausted.
|
||||
|
||||
This approach handles partial JSON objects that span multiple TCP chunks correctly.
|
||||
|
||||
### 3. gRPC Handler Wiring
|
||||
|
||||
This issue does **not** implement the gRPC handler wiring — that is deferred to subsequent issues. However, the `OllamaClient` must be integrated into `ModelGatewayServiceImpl` so that future handler implementations can use it.
|
||||
|
||||
**Update `services/model-gateway/src/service.rs`:**
|
||||
|
||||
Add `OllamaClient` as a field on `ModelGatewayServiceImpl`:
|
||||
|
||||
```rust
|
||||
use crate::ollama::OllamaClient;
|
||||
|
||||
pub struct ModelGatewayServiceImpl {
|
||||
config: Config,
|
||||
ollama: OllamaClient,
|
||||
}
|
||||
|
||||
impl ModelGatewayServiceImpl {
|
||||
pub fn new(config: Config) -> anyhow::Result<Self> {
|
||||
let ollama = OllamaClient::new(&config)?;
|
||||
Ok(Self { config, ollama })
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Note: The constructor changes from infallible to `Result` since `reqwest::Client::builder().build()` can fail. Update `main.rs` accordingly to use `?`.
|
||||
|
||||
**Update `services/model-gateway/src/main.rs`:**
|
||||
|
||||
Change `ModelGatewayServiceImpl::new(config)` to `ModelGatewayServiceImpl::new(config)?`.
|
||||
|
||||
### 4. Service Integration
|
||||
|
||||
No cross-service integration is needed for this issue. The `OllamaClient` is a standalone HTTP client that talks to the local Ollama instance. Integration with gRPC handlers will happen in follow-up issues.
|
||||
|
||||
### 5. Tests
|
||||
|
||||
**Unit tests for serde types in `services/model-gateway/src/ollama/types.rs`:**
|
||||
|
||||
| Test Case | Description |
|
||||
|---|---|
|
||||
| `test_generate_request_serialization` | `GenerateRequest` serializes to expected JSON with `stream: false` |
|
||||
| `test_generate_request_serialization_with_options` | Options fields are included when `Some`, omitted when `None` |
|
||||
| `test_generate_response_deserialization` | Deserialize a complete Ollama generate response JSON |
|
||||
| `test_generate_response_missing_optional_fields` | Optional fields default to `None` when absent |
|
||||
| `test_generate_stream_chunk_deserialization` | Deserialize a streaming chunk (partial token, `done: false`) |
|
||||
| `test_generate_stream_chunk_final` | Deserialize final chunk with `done: true` and `done_reason` |
|
||||
| `test_chat_request_serialization` | `ChatRequest` with multiple messages serializes correctly |
|
||||
| `test_chat_role_serialization` | `ChatRole` variants serialize as lowercase strings |
|
||||
| `test_chat_response_deserialization` | Deserialize a complete chat response |
|
||||
| `test_embed_request_serialization` | `EmbedRequest` with multiple inputs serializes correctly |
|
||||
| `test_embed_response_deserialization` | Deserialize embedding response with vector data |
|
||||
| `test_list_models_response_deserialization` | Deserialize model listing with multiple models |
|
||||
| `test_model_info_optional_fields` | `ModelInfo` handles missing `digest` gracefully |
|
||||
|
||||
**Unit tests for error handling in `services/model-gateway/src/ollama/error.rs`:**
|
||||
|
||||
| Test Case | Description |
|
||||
|---|---|
|
||||
| `test_error_display_http` | `OllamaError::Http` formats with reqwest message |
|
||||
| `test_error_display_api` | `OllamaError::Api` includes status code and message |
|
||||
| `test_error_display_deserialization` | `OllamaError::Deserialization` includes detail |
|
||||
|
||||
**Integration-style tests for `OllamaClient` in `services/model-gateway/src/ollama/client.rs`:**
|
||||
|
||||
Use a mock HTTP server (either `mockito` or `wiremock`) to simulate Ollama API responses:
|
||||
|
||||
| Test Case | Description |
|
||||
|---|---|
|
||||
| `test_generate_success` | Mock `/api/generate` returns valid JSON, verify parsed response |
|
||||
| `test_generate_with_options` | Verify temperature, top_p, num_predict, stop are sent in request body |
|
||||
| `test_generate_stream_success` | Mock returns NDJSON with 3 chunks + final, verify all chunks yielded |
|
||||
| `test_generate_stream_empty_response` | Mock returns single `done: true` chunk |
|
||||
| `test_chat_success` | Mock `/api/chat` returns valid response, verify message parsing |
|
||||
| `test_chat_with_history` | Send multi-message conversation, verify all messages in request body |
|
||||
| `test_embed_success` | Mock `/api/embed` returns embedding vectors, verify dimensions |
|
||||
| `test_embed_multiple_inputs` | Send multiple texts, verify multiple embeddings returned |
|
||||
| `test_list_models_success` | Mock `/api/tags` returns model list |
|
||||
| `test_list_models_empty` | Mock returns empty model list |
|
||||
| `test_is_healthy_success` | Mock `/api/tags` returns 200, `is_healthy()` returns true |
|
||||
| `test_is_healthy_failure` | Mock returns 500, `is_healthy()` returns false |
|
||||
| `test_api_error_404` | Mock returns 404 with error message, verify `OllamaError::Api` |
|
||||
| `test_api_error_500` | Mock returns 500 with error body, verify error extraction |
|
||||
| `test_connection_timeout` | Client configured with very short timeout, verify `OllamaError::Http` |
|
||||
| `test_base_url_trailing_slash` | Config URL with trailing slash is normalized |
|
||||
|
||||
**Mocking strategy:**
|
||||
|
||||
Use `wiremock` crate as a dev-dependency. It provides a `MockServer` that binds to a random port, allowing parallel test execution without port conflicts. Each test creates its own `MockServer`, configures expected requests/responses, then creates an `OllamaClient` pointed at the mock server URL.
|
||||
|
||||
For streaming tests, the mock server returns a response body containing multiple NDJSON lines separated by `\n`.
|
||||
|
||||
**Configuration tests in `services/model-gateway/src/config.rs`:**
|
||||
|
||||
| Test Case | Description |
|
||||
|---|---|
|
||||
| `test_client_config_defaults` | `OllamaClientConfig::default()` returns expected timeout/pool values |
|
||||
| `test_client_config_from_toml` | Custom client config loads from TOML |
|
||||
| `test_config_with_client_section` | Full `Config` with `[client]` section parses correctly |
|
||||
|
||||
## Files to Create/Modify
|
||||
|
||||
| File | Action | Purpose |
|
||||
|---|---|---|
|
||||
| `services/model-gateway/Cargo.toml` | Modify | Add `reqwest` (with `json`, `stream` features), `futures`, `serde_json` dependencies; add `wiremock` dev-dependency |
|
||||
| `services/model-gateway/src/config.rs` | Modify | Add `OllamaClientConfig` struct with timeout/pool settings; add `client` field to `Config` |
|
||||
| `services/model-gateway/src/ollama/mod.rs` | Create | Module declaration, re-exports of `OllamaClient`, `OllamaError`, and types |
|
||||
| `services/model-gateway/src/ollama/types.rs` | Create | Serde request/response structs for all Ollama API endpoints |
|
||||
| `services/model-gateway/src/ollama/error.rs` | Create | `OllamaError` enum with `Http`, `Api`, `Deserialization`, `StreamIncomplete` variants |
|
||||
| `services/model-gateway/src/ollama/client.rs` | Create | `OllamaClient` struct with `generate`, `generate_stream`, `chat`, `embed`, `list_models`, `is_healthy` methods and NDJSON stream parser |
|
||||
| `services/model-gateway/src/lib.rs` | Modify | Add `pub mod ollama;` |
|
||||
| `services/model-gateway/src/service.rs` | Modify | Add `OllamaClient` field to `ModelGatewayServiceImpl`; change constructor to return `Result` |
|
||||
| `services/model-gateway/src/main.rs` | Modify | Update `ModelGatewayServiceImpl::new(config)` call to handle `Result` with `?` |
|
||||
|
||||
## Risks and Edge Cases
|
||||
|
||||
- **Streaming NDJSON parsing:** Ollama sends newline-delimited JSON. TCP chunks may not align with JSON object boundaries — a single chunk could contain a partial JSON line or multiple lines. The buffer-based `ndjson_stream` implementation must handle both cases. Mitigation: accumulate bytes until `\n` is found, only parse complete lines.
|
||||
- **Large model response times:** Inference on large models (14B+) can take minutes. The default request timeout of 300 seconds should be sufficient, but this is configurable. Streaming mitigates perceived latency by yielding tokens incrementally.
|
||||
- **Ollama API version compatibility:** The `/api/embed` endpoint (with `input` array) was introduced in Ollama 0.1.44+. Older Ollama versions use `/api/embeddings` with a different request shape. Mitigation: target the newer API. Document the minimum Ollama version requirement.
|
||||
- **Connection pool exhaustion:** If many concurrent gRPC requests hit the gateway simultaneously, the reqwest connection pool could be exhausted. Mitigation: `pool_max_idle` is configurable; the default of 10 is reasonable for a single-node setup. Consider adding a semaphore for concurrency limiting in a future issue if needed.
|
||||
- **`wiremock` test isolation:** Each test creates its own `MockServer` on a random port, so tests can run in parallel safely. However, `wiremock` adds to dev-dependency compile time.
|
||||
- **Constructor change breaks existing tests:** Changing `ModelGatewayServiceImpl::new()` from infallible to `Result` will break existing tests in `service.rs`. Mitigation: update the test helper `test_config()` to also construct the `OllamaClient`, or use a test-only constructor that accepts a pre-built client. Alternatively, keep a separate `new_with_client()` constructor for testability and dependency injection.
|
||||
- **reqwest TLS:** The default reqwest build pulls in `rustls` or `native-tls`. Since Ollama runs locally over plain HTTP, TLS is not needed. Consider using `default-features = false` with just the required features to minimize compile time and binary size. However, if a user runs Ollama behind a TLS reverse proxy, TLS support is needed. Mitigation: use default features (includes TLS) for now; optimize later if compile time is a concern.
|
||||
|
||||
## Deviation Log
|
||||
|
||||
_(Filled during implementation if deviations from plan occur)_
|
||||
|
||||
| Deviation | Reason |
|
||||
|---|---|
|
||||
@@ -18,6 +18,11 @@ toml = "0.8"
|
||||
anyhow = "1"
|
||||
thiserror = "2"
|
||||
tokio-stream = "0.1"
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
futures = "0.3"
|
||||
serde_json = "1"
|
||||
bytes = "1"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
wiremock = "0.6"
|
||||
|
||||
@@ -71,6 +71,53 @@ impl ModelRoutingConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for the Ollama HTTP client.
///
/// Every field is optional in TOML; a missing field falls back to the
/// `default_*` helper named in its `#[serde(default = "...")]`
/// attribute, so a bare `[client]` section (or none at all, when the
/// parent struct uses `#[serde(default)]`) yields the same values as
/// `OllamaClientConfig::default()`.
#[derive(Debug, Clone, Deserialize)]
pub struct OllamaClientConfig {
    /// Request timeout in seconds (default: 300).
    /// Covers the whole request including body read; sized generously
    /// because large-model inference can take minutes.
    #[serde(default = "default_request_timeout_secs")]
    pub request_timeout_secs: u64,

    /// Connection timeout in seconds (default: 10).
    #[serde(default = "default_connect_timeout_secs")]
    pub connect_timeout_secs: u64,

    /// Maximum idle connections in the pool (default: 10).
    #[serde(default = "default_pool_max_idle")]
    pub pool_max_idle: usize,

    /// Idle connection timeout in seconds (default: 60).
    #[serde(default = "default_pool_idle_timeout_secs")]
    pub pool_idle_timeout_secs: u64,
}
|
||||
|
||||
// Serde default helpers for `OllamaClientConfig`. These are free
// functions (rather than inline constants) because
// `#[serde(default = "...")]` requires a function path.
fn default_request_timeout_secs() -> u64 { 300 }

fn default_connect_timeout_secs() -> u64 { 10 }

fn default_pool_max_idle() -> usize { 10 }

fn default_pool_idle_timeout_secs() -> u64 { 60 }
|
||||
|
||||
impl Default for OllamaClientConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
request_timeout_secs: default_request_timeout_secs(),
|
||||
connect_timeout_secs: default_connect_timeout_secs(),
|
||||
pool_max_idle: default_pool_max_idle(),
|
||||
pool_idle_timeout_secs: default_pool_idle_timeout_secs(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Top-level configuration for the Model Gateway service.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct Config {
|
||||
@@ -92,6 +139,10 @@ pub struct Config {
|
||||
/// Model routing configuration.
|
||||
#[serde(default)]
|
||||
pub routing: ModelRoutingConfig,
|
||||
|
||||
/// Ollama HTTP client configuration.
|
||||
#[serde(default)]
|
||||
pub client: OllamaClientConfig,
|
||||
}
|
||||
|
||||
fn default_host() -> String {
|
||||
@@ -114,6 +165,7 @@ impl Default for Config {
|
||||
ollama_url: default_ollama_url(),
|
||||
audit_addr: None,
|
||||
routing: ModelRoutingConfig::default(),
|
||||
client: OllamaClientConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -230,6 +282,48 @@ code = "codellama:7b"
|
||||
assert_eq!(count, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_config_defaults() {
|
||||
let cc = OllamaClientConfig::default();
|
||||
assert_eq!(cc.request_timeout_secs, 300);
|
||||
assert_eq!(cc.connect_timeout_secs, 10);
|
||||
assert_eq!(cc.pool_max_idle, 10);
|
||||
assert_eq!(cc.pool_idle_timeout_secs, 60);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_config_from_toml() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let config_path = dir.path().join("gateway.toml");
|
||||
std::fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
host = "0.0.0.0"
|
||||
port = 9999
|
||||
|
||||
[client]
|
||||
request_timeout_secs = 600
|
||||
connect_timeout_secs = 5
|
||||
pool_max_idle = 20
|
||||
pool_idle_timeout_secs = 120
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
|
||||
assert_eq!(config.client.request_timeout_secs, 600);
|
||||
assert_eq!(config.client.connect_timeout_secs, 5);
|
||||
assert_eq!(config.client.pool_max_idle, 20);
|
||||
assert_eq!(config.client.pool_idle_timeout_secs, 120);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_config_uses_defaults_when_omitted() {
|
||||
let config = Config::default();
|
||||
assert_eq!(config.client.request_timeout_secs, 300);
|
||||
assert_eq!(config.client.connect_timeout_secs, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_routing_from_toml_uses_defaults_when_omitted() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
pub mod config;
|
||||
pub mod ollama;
|
||||
pub mod service;
|
||||
|
||||
@@ -23,7 +23,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
);
|
||||
|
||||
let addr = config.listen_addr().parse()?;
|
||||
let service = ModelGatewayServiceImpl::new(config);
|
||||
let service = ModelGatewayServiceImpl::new(config)?;
|
||||
|
||||
tracing::info!(%addr, "Model Gateway listening");
|
||||
|
||||
|
||||
570
services/model-gateway/src/ollama/client.rs
Normal file
570
services/model-gateway/src/ollama/client.rs
Normal file
@@ -0,0 +1,570 @@
|
||||
use std::pin::Pin;
|
||||
use std::time::Duration;
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::stream::Stream;
|
||||
use futures::StreamExt;
|
||||
use reqwest::Client;
|
||||
|
||||
use super::error::OllamaError;
|
||||
use super::types::*;
|
||||
use crate::config::Config;
|
||||
|
||||
/// Async HTTP client for the Ollama REST API.
///
/// Wraps `reqwest::Client` with connection pooling, timeouts, and
/// typed methods for each Ollama endpoint.
pub struct OllamaClient {
    // Pooled reqwest client; configured once in `new` from
    // `OllamaClientConfig` and reused for every request.
    client: Client,
    // Ollama base URL with any trailing '/' stripped (see `new`), so
    // endpoint paths can be appended with a single format! call.
    base_url: String,
}
|
||||
|
||||
impl OllamaClient {
|
||||
/// Create a new client from the service configuration.
|
||||
pub fn new(config: &Config) -> Result<Self, OllamaError> {
|
||||
let cc = &config.client;
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(cc.request_timeout_secs))
|
||||
.connect_timeout(Duration::from_secs(cc.connect_timeout_secs))
|
||||
.pool_max_idle_per_host(cc.pool_max_idle)
|
||||
.pool_idle_timeout(Duration::from_secs(cc.pool_idle_timeout_secs))
|
||||
.build()?;
|
||||
|
||||
let base_url = config.ollama_url.trim_end_matches('/').to_string();
|
||||
Ok(Self { client, base_url })
|
||||
}
|
||||
|
||||
/// POST /api/generate (non-streaming).
|
||||
pub async fn generate(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt: &str,
|
||||
options: Option<GenerateOptions>,
|
||||
) -> Result<GenerateResponse, OllamaError> {
|
||||
let request = GenerateRequest {
|
||||
model: model.to_string(),
|
||||
prompt: prompt.to_string(),
|
||||
stream: false,
|
||||
options,
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(format!("{}/api/generate", self.base_url))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let resp = Self::check_status(resp).await?;
|
||||
resp.json::<GenerateResponse>()
|
||||
.await
|
||||
.map_err(|e| OllamaError::Deserialization(e.to_string()))
|
||||
}
|
||||
|
||||
/// POST /api/generate (streaming).
|
||||
///
|
||||
/// Returns a `Stream` of `GenerateStreamChunk` items parsed from NDJSON.
|
||||
pub async fn generate_stream(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt: &str,
|
||||
options: Option<GenerateOptions>,
|
||||
) -> Result<
|
||||
Pin<Box<dyn Stream<Item = Result<GenerateStreamChunk, OllamaError>> + Send>>,
|
||||
OllamaError,
|
||||
> {
|
||||
let request = GenerateRequest {
|
||||
model: model.to_string(),
|
||||
prompt: prompt.to_string(),
|
||||
stream: true,
|
||||
options,
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(format!("{}/api/generate", self.base_url))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let resp = Self::check_status(resp).await?;
|
||||
Ok(Self::ndjson_stream::<GenerateStreamChunk>(resp))
|
||||
}
|
||||
|
||||
/// POST /api/chat (non-streaming).
|
||||
pub async fn chat(
|
||||
&self,
|
||||
model: &str,
|
||||
messages: Vec<ChatMessage>,
|
||||
options: Option<GenerateOptions>,
|
||||
) -> Result<ChatResponse, OllamaError> {
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
stream: false,
|
||||
options,
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(format!("{}/api/chat", self.base_url))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let resp = Self::check_status(resp).await?;
|
||||
resp.json::<ChatResponse>()
|
||||
.await
|
||||
.map_err(|e| OllamaError::Deserialization(e.to_string()))
|
||||
}
|
||||
|
||||
/// POST /api/embed.
|
||||
pub async fn embed(
|
||||
&self,
|
||||
model: &str,
|
||||
input: Vec<String>,
|
||||
) -> Result<EmbedResponse, OllamaError> {
|
||||
let request = EmbedRequest {
|
||||
model: model.to_string(),
|
||||
input,
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(format!("{}/api/embed", self.base_url))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let resp = Self::check_status(resp).await?;
|
||||
resp.json::<EmbedResponse>()
|
||||
.await
|
||||
.map_err(|e| OllamaError::Deserialization(e.to_string()))
|
||||
}
|
||||
|
||||
/// GET /api/tags — list all models available on the Ollama instance.
|
||||
pub async fn list_models(&self) -> Result<ListModelsResponse, OllamaError> {
|
||||
let resp = self
|
||||
.client
|
||||
.get(format!("{}/api/tags", self.base_url))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let resp = Self::check_status(resp).await?;
|
||||
resp.json::<ListModelsResponse>()
|
||||
.await
|
||||
.map_err(|e| OllamaError::Deserialization(e.to_string()))
|
||||
}
|
||||
|
||||
/// Check if Ollama is reachable.
|
||||
pub async fn is_healthy(&self) -> bool {
|
||||
self.list_models().await.is_ok()
|
||||
}
|
||||
|
||||
/// Parse an NDJSON response body into a typed stream.
|
||||
///
|
||||
/// Ollama streams newline-delimited JSON. Each line is a complete
|
||||
/// JSON object. This buffers bytes until newlines are found, then
|
||||
/// deserializes each complete line.
|
||||
fn ndjson_stream<T: serde::de::DeserializeOwned + Send + 'static>(
|
||||
resp: reqwest::Response,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<T, OllamaError>> + Send>> {
|
||||
let byte_stream = resp.bytes_stream();
|
||||
|
||||
type ByteStream = Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>;
|
||||
|
||||
Box::pin(futures::stream::unfold(
|
||||
(Box::pin(byte_stream) as ByteStream, Vec::<u8>::new()),
|
||||
|(mut stream, mut buf): (ByteStream, Vec<u8>)| async move {
|
||||
loop {
|
||||
// Check if buffer contains a complete line.
|
||||
if let Some(pos) = buf.iter().position(|&b| b == b'\n') {
|
||||
let line = buf.drain(..=pos).collect::<Vec<_>>();
|
||||
let line = &line[..line.len() - 1]; // strip newline
|
||||
if line.is_empty() {
|
||||
continue; // skip empty lines
|
||||
}
|
||||
let result = serde_json::from_slice::<T>(line)
|
||||
.map_err(|e| OllamaError::Deserialization(e.to_string()));
|
||||
return Some((result, (stream, buf)));
|
||||
}
|
||||
|
||||
// Need more data from the stream.
|
||||
match stream.next().await {
|
||||
Some(Ok(chunk)) => {
|
||||
buf.extend_from_slice(&chunk);
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
return Some((Err(OllamaError::Http(e)), (stream, buf)));
|
||||
}
|
||||
None => {
|
||||
// Stream ended. Parse any remaining data in buffer.
|
||||
if buf.is_empty() {
|
||||
return None;
|
||||
}
|
||||
// Trim trailing whitespace
|
||||
let trimmed = buf.trim_ascii();
|
||||
if trimmed.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let result = serde_json::from_slice::<T>(trimmed)
|
||||
.map_err(|e| OllamaError::Deserialization(e.to_string()));
|
||||
buf.clear();
|
||||
return Some((result, (stream, buf)));
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
/// Check response status and extract error message for non-2xx.
|
||||
async fn check_status(resp: reqwest::Response) -> Result<reqwest::Response, OllamaError> {
|
||||
if resp.status().is_success() {
|
||||
return Ok(resp);
|
||||
}
|
||||
let status = resp.status().as_u16();
|
||||
let message = resp
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "unknown error".to_string());
|
||||
Err(OllamaError::Api { status, message })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use wiremock::matchers::{method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
|
||||
fn test_client(base_url: &str) -> OllamaClient {
|
||||
let config = Config {
|
||||
ollama_url: base_url.to_string(),
|
||||
..Config::default()
|
||||
};
|
||||
OllamaClient::new(&config).unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_success() {
|
||||
let mock_server = MockServer::start().await;
|
||||
let body = r#"{"model":"llama3.2:3b","response":"Hello world!","done":true,"done_reason":"stop","eval_count":5}"#;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let client = test_client(&mock_server.uri());
|
||||
let resp = client.generate("llama3.2:3b", "Hi", None).await.unwrap();
|
||||
assert_eq!(resp.response, "Hello world!");
|
||||
assert!(resp.done);
|
||||
assert_eq!(resp.eval_count, Some(5));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_with_options() {
|
||||
let mock_server = MockServer::start().await;
|
||||
let body = r#"{"model":"llama3.2:3b","response":"test","done":true}"#;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let client = test_client(&mock_server.uri());
|
||||
let opts = GenerateOptions {
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(0.9),
|
||||
num_predict: Some(100),
|
||||
stop: Some(vec!["STOP".to_string()]),
|
||||
};
|
||||
let resp = client
|
||||
.generate("llama3.2:3b", "Hi", Some(opts))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.response, "test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_stream_success() {
|
||||
let mock_server = MockServer::start().await;
|
||||
let ndjson = concat!(
|
||||
r#"{"model":"llama3.2:3b","response":"Hello","done":false}"#, "\n",
|
||||
r#"{"model":"llama3.2:3b","response":" world","done":false}"#, "\n",
|
||||
r#"{"model":"llama3.2:3b","response":"!","done":false}"#, "\n",
|
||||
r#"{"model":"llama3.2:3b","response":"","done":true,"done_reason":"stop","eval_count":3}"#, "\n",
|
||||
);
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200).set_body_raw(ndjson, "application/x-ndjson"),
|
||||
)
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let client = test_client(&mock_server.uri());
|
||||
let stream = client
|
||||
.generate_stream("llama3.2:3b", "Hi", None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let chunks: Vec<_> = stream.collect::<Vec<_>>().await;
|
||||
assert_eq!(chunks.len(), 4);
|
||||
assert_eq!(chunks[0].as_ref().unwrap().response, "Hello");
|
||||
assert_eq!(chunks[1].as_ref().unwrap().response, " world");
|
||||
assert_eq!(chunks[2].as_ref().unwrap().response, "!");
|
||||
assert!(chunks[3].as_ref().unwrap().done);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_stream_single_done() {
|
||||
let mock_server = MockServer::start().await;
|
||||
let ndjson = r#"{"model":"llama3.2:3b","response":"","done":true,"done_reason":"stop"}"#;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200).set_body_raw(ndjson, "application/x-ndjson"),
|
||||
)
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let client = test_client(&mock_server.uri());
|
||||
let stream = client
|
||||
.generate_stream("llama3.2:3b", "Hi", None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let chunks: Vec<_> = stream.collect::<Vec<_>>().await;
|
||||
assert_eq!(chunks.len(), 1);
|
||||
assert!(chunks[0].as_ref().unwrap().done);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_chat_success() {
|
||||
let mock_server = MockServer::start().await;
|
||||
let body = r#"{"model":"llama3.2:3b","message":{"role":"assistant","content":"Hi there!"},"done":true,"done_reason":"stop"}"#;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/chat"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let client = test_client(&mock_server.uri());
|
||||
let msgs = vec![ChatMessage {
|
||||
role: ChatRole::User,
|
||||
content: "Hello".to_string(),
|
||||
}];
|
||||
let resp = client.chat("llama3.2:3b", msgs, None).await.unwrap();
|
||||
assert_eq!(resp.message.content, "Hi there!");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_chat_with_history() {
|
||||
let mock_server = MockServer::start().await;
|
||||
let body = r#"{"model":"llama3.2:3b","message":{"role":"assistant","content":"Fine, thanks!"},"done":true}"#;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/chat"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let client = test_client(&mock_server.uri());
|
||||
let msgs = vec![
|
||||
ChatMessage {
|
||||
role: ChatRole::System,
|
||||
content: "Be helpful".to_string(),
|
||||
},
|
||||
ChatMessage {
|
||||
role: ChatRole::User,
|
||||
content: "Hi".to_string(),
|
||||
},
|
||||
ChatMessage {
|
||||
role: ChatRole::Assistant,
|
||||
content: "Hello!".to_string(),
|
||||
},
|
||||
ChatMessage {
|
||||
role: ChatRole::User,
|
||||
content: "How are you?".to_string(),
|
||||
},
|
||||
];
|
||||
let resp = client.chat("llama3.2:3b", msgs, None).await.unwrap();
|
||||
assert_eq!(resp.message.content, "Fine, thanks!");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_embed_success() {
|
||||
let mock_server = MockServer::start().await;
|
||||
let body = r#"{"model":"nomic-embed-text","embeddings":[[0.1,0.2,0.3]]}"#;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/embed"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let client = test_client(&mock_server.uri());
|
||||
let resp = client
|
||||
.embed("nomic-embed-text", vec!["hello".to_string()])
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.embeddings.len(), 1);
|
||||
assert_eq!(resp.embeddings[0].len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_embed_multiple_inputs() {
|
||||
let mock_server = MockServer::start().await;
|
||||
let body = r#"{"model":"nomic-embed-text","embeddings":[[0.1,0.2],[0.3,0.4]]}"#;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/embed"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let client = test_client(&mock_server.uri());
|
||||
let resp = client
|
||||
.embed(
|
||||
"nomic-embed-text",
|
||||
vec!["hello".to_string(), "world".to_string()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.embeddings.len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_models_success() {
|
||||
let mock_server = MockServer::start().await;
|
||||
let body = r#"{"models":[{"name":"llama3.2:3b","model":"llama3.2:3b","size":1234567},{"name":"nomic-embed-text","model":"nomic-embed-text","size":987654}]}"#;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/api/tags"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let client = test_client(&mock_server.uri());
|
||||
let resp = client.list_models().await.unwrap();
|
||||
assert_eq!(resp.models.len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_models_empty() {
|
||||
let mock_server = MockServer::start().await;
|
||||
let body = r#"{"models":[]}"#;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/api/tags"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let client = test_client(&mock_server.uri());
|
||||
let resp = client.list_models().await.unwrap();
|
||||
assert!(resp.models.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_is_healthy_success() {
|
||||
let mock_server = MockServer::start().await;
|
||||
let body = r#"{"models":[]}"#;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/api/tags"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let client = test_client(&mock_server.uri());
|
||||
assert!(client.is_healthy().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_is_healthy_failure() {
|
||||
let mock_server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/api/tags"))
|
||||
.respond_with(ResponseTemplate::new(500).set_body_string("internal error"))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let client = test_client(&mock_server.uri());
|
||||
assert!(!client.is_healthy().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_api_error_404() {
|
||||
let mock_server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(404).set_body_string("model 'nonexistent' not found"))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let client = test_client(&mock_server.uri());
|
||||
let err = client
|
||||
.generate("nonexistent", "Hi", None)
|
||||
.await
|
||||
.unwrap_err();
|
||||
match err {
|
||||
OllamaError::Api { status, message } => {
|
||||
assert_eq!(status, 404);
|
||||
assert!(message.contains("not found"));
|
||||
}
|
||||
other => panic!("expected Api error, got: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_api_error_500() {
|
||||
let mock_server = MockServer::start().await;
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/api/generate"))
|
||||
.respond_with(ResponseTemplate::new(500).set_body_string("internal server error"))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let client = test_client(&mock_server.uri());
|
||||
let err = client
|
||||
.generate("llama3.2:3b", "Hi", None)
|
||||
.await
|
||||
.unwrap_err();
|
||||
match err {
|
||||
OllamaError::Api { status, .. } => assert_eq!(status, 500),
|
||||
other => panic!("expected Api error, got: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_base_url_trailing_slash() {
|
||||
let mock_server = MockServer::start().await;
|
||||
let body = r#"{"models":[]}"#;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/api/tags"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
// URL with trailing slash
|
||||
let config = Config {
|
||||
ollama_url: format!("{}/", mock_server.uri()),
|
||||
..Config::default()
|
||||
};
|
||||
let client = OllamaClient::new(&config).unwrap();
|
||||
let resp = client.list_models().await.unwrap();
|
||||
assert!(resp.models.is_empty());
|
||||
}
|
||||
}
|
||||
51
services/model-gateway/src/ollama/error.rs
Normal file
51
services/model-gateway/src/ollama/error.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors that can occur when communicating with the Ollama API.
#[derive(Debug, Error)]
pub enum OllamaError {
    /// HTTP-level error (connection refused, timeout, DNS, etc.).
    /// `#[from]` gives a `From<reqwest::Error>` impl so `?` converts
    /// transport failures automatically.
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),

    /// Ollama returned a non-2xx status code.
    #[error("Ollama API error (status {status}): {message}")]
    Api {
        // HTTP status code of the failed response.
        status: u16,
        // Body text returned by Ollama describing the failure.
        message: String,
    },

    /// Failed to deserialize Ollama JSON response.
    #[error("deserialization error: {0}")]
    Deserialization(String),

    /// Stream terminated unexpectedly without a done:true chunk.
    #[error("stream ended unexpectedly")]
    StreamIncomplete,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Display for the Api variant must surface status and server message.
    #[test]
    fn test_error_display_api() {
        let error = OllamaError::Api {
            status: 404,
            message: "model not found".to_string(),
        };
        let rendered = error.to_string();
        assert!(rendered.contains("404"));
        assert!(rendered.contains("model not found"));
    }

    /// Display for Deserialization must include the wrapped detail string.
    #[test]
    fn test_error_display_deserialization() {
        let rendered = OllamaError::Deserialization("unexpected EOF".to_string()).to_string();
        assert!(rendered.contains("unexpected EOF"));
    }

    /// Display for StreamIncomplete is a fixed message.
    #[test]
    fn test_error_display_stream_incomplete() {
        let rendered = OllamaError::StreamIncomplete.to_string();
        assert!(rendered.contains("stream ended unexpectedly"));
    }
}
|
||||
6
services/model-gateway/src/ollama/mod.rs
Normal file
6
services/model-gateway/src/ollama/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
//! Async client for the Ollama REST API.
//!
//! `client` holds the HTTP client, `error` the error type, and `types`
//! the serde request/response models.

pub mod client;
pub mod error;
pub mod types;

// Re-export the two items most callers need.
pub use client::OllamaClient;
pub use error::OllamaError;
|
||||
301
services/model-gateway/src/ollama/types.rs
Normal file
301
services/model-gateway/src/ollama/types.rs
Normal file
@@ -0,0 +1,301 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// --- /api/generate ---
|
||||
|
||||
/// Request body for `POST /api/generate`.
#[derive(Debug, Serialize)]
pub struct GenerateRequest {
    /// Name of the model to run (e.g. "llama3.2:3b").
    pub model: String,
    /// Prompt text sent to the model.
    pub prompt: String,
    /// `false` for one JSON response; `true` for NDJSON streaming chunks.
    pub stream: bool,
    /// Optional sampling parameters; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<GenerateOptions>,
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct GenerateOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub num_predict: Option<i32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Full response from /api/generate with stream:false.
#[derive(Debug, Deserialize)]
pub struct GenerateResponse {
    /// Model name echoed back by Ollama.
    pub model: String,
    /// The generated completion text.
    pub response: String,
    /// Whether generation has finished.
    pub done: bool,
    /// Why generation stopped (e.g. "stop"); may be absent.
    #[serde(default)]
    pub done_reason: Option<String>,
    /// Total request duration. NOTE(review): presumably nanoseconds per the
    /// Ollama API — confirm against the API docs.
    #[serde(default)]
    pub total_duration: Option<u64>,
    /// NOTE(review): presumably the number of generated tokens — confirm.
    #[serde(default)]
    pub eval_count: Option<u32>,
    /// NOTE(review): presumably the number of prompt tokens — confirm.
    #[serde(default)]
    pub prompt_eval_count: Option<u32>,
}
|
||||
|
||||
/// Single chunk from /api/generate with stream:true (NDJSON).
#[derive(Debug, Deserialize)]
pub struct GenerateStreamChunk {
    /// Model name echoed back by Ollama.
    pub model: String,
    /// Text fragment carried by this chunk; the final chunk may carry an
    /// empty string (see the stream tests below).
    pub response: String,
    /// `true` only on the terminating chunk of the stream.
    pub done: bool,
    /// Why generation stopped; present on the final chunk.
    #[serde(default)]
    pub done_reason: Option<String>,
    /// NOTE(review): presumably generated-token count on the final chunk — confirm.
    #[serde(default)]
    pub eval_count: Option<u32>,
    /// NOTE(review): presumably prompt-token count on the final chunk — confirm.
    #[serde(default)]
    pub prompt_eval_count: Option<u32>,
}
|
||||
|
||||
// --- /api/chat ---
|
||||
|
||||
/// Author of a chat message; serialized as the lowercase variant name
/// ("system" / "user" / "assistant") via `rename_all`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    System,
    User,
    Assistant,
}
|
||||
|
||||
/// One message in a chat conversation (used in both requests and responses).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who authored the message.
    pub role: ChatRole,
    /// The message text.
    pub content: String,
}
|
||||
|
||||
/// Request body for `POST /api/chat`.
#[derive(Debug, Serialize)]
pub struct ChatRequest {
    /// Name of the model to chat with.
    pub model: String,
    /// Conversation messages, in order.
    pub messages: Vec<ChatMessage>,
    /// `false` for one JSON response; `true` for streaming.
    pub stream: bool,
    /// Optional sampling parameters (same shape as /api/generate); omitted
    /// from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<GenerateOptions>,
}
|
||||
|
||||
/// Response from /api/chat with stream:false.
#[derive(Debug, Deserialize)]
pub struct ChatResponse {
    /// Model name echoed back by Ollama.
    pub model: String,
    /// The assistant's reply.
    pub message: ChatMessage,
    /// Whether generation has finished.
    pub done: bool,
    /// Why generation stopped (e.g. "stop"); may be absent.
    #[serde(default)]
    pub done_reason: Option<String>,
    /// Total request duration. NOTE(review): presumably nanoseconds — confirm.
    #[serde(default)]
    pub total_duration: Option<u64>,
    /// NOTE(review): presumably generated-token count — confirm.
    #[serde(default)]
    pub eval_count: Option<u32>,
    /// NOTE(review): presumably prompt-token count — confirm.
    #[serde(default)]
    pub prompt_eval_count: Option<u32>,
}
|
||||
|
||||
// --- /api/embed ---
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct EmbedRequest {
|
||||
pub model: String,
|
||||
pub input: Vec<String>,
|
||||
}
|
||||
|
||||
/// Response from /api/embed.
#[derive(Debug, Deserialize)]
pub struct EmbedResponse {
    /// Model name echoed back by Ollama.
    pub model: String,
    /// One embedding vector per input text.
    pub embeddings: Vec<Vec<f32>>,
}
|
||||
|
||||
// --- /api/tags (list models) ---
|
||||
|
||||
/// Response from /api/tags listing the models Ollama reports.
#[derive(Debug, Deserialize)]
pub struct ListModelsResponse {
    /// All reported models; may be empty.
    pub models: Vec<ModelInfo>,
}
|
||||
|
||||
/// Metadata for one model entry from /api/tags.
#[derive(Debug, Deserialize)]
pub struct ModelInfo {
    /// Model name (e.g. "llama3.2:3b").
    pub name: String,
    /// Model identifier; matches `name` in the test fixtures.
    pub model: String,
    /// Model size — presumably bytes (confirm against Ollama docs);
    /// defaults to 0 when the field is absent.
    #[serde(default)]
    pub size: u64,
    /// Content digest of the model, when reported.
    #[serde(default)]
    pub digest: Option<String>,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// `options: None` must vanish from the payload, not serialize as null.
    #[test]
    fn test_generate_request_serialization() {
        let request = GenerateRequest {
            model: "llama3.2:3b".to_string(),
            prompt: "Hello".to_string(),
            stream: false,
            options: None,
        };
        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(value["model"], "llama3.2:3b");
        assert_eq!(value["prompt"], "Hello");
        assert_eq!(value["stream"], false);
        assert!(value.get("options").is_none());
    }

    /// Set options appear under "options"; None fields inside are skipped.
    #[test]
    fn test_generate_request_serialization_with_options() {
        let options = GenerateOptions {
            temperature: Some(0.7),
            top_p: None,
            num_predict: Some(100),
            stop: Some(vec!["<|end|>".to_string()]),
        };
        let request = GenerateRequest {
            model: "llama3.2:3b".to_string(),
            prompt: "Hello".to_string(),
            stream: false,
            options: Some(options),
        };
        let serialized = serde_json::to_value(&request).unwrap();
        let opts = &serialized["options"];
        assert!((opts["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
        assert!(opts.get("top_p").is_none());
        assert_eq!(opts["num_predict"], 100);
        assert_eq!(opts["stop"][0], "<|end|>");
    }

    #[test]
    fn test_generate_response_deserialization() {
        let payload = r#"{
            "model": "llama3.2:3b",
            "response": "Hello world!",
            "done": true,
            "done_reason": "stop",
            "total_duration": 1234567890,
            "eval_count": 42,
            "prompt_eval_count": 10
        }"#;
        let decoded: GenerateResponse = serde_json::from_str(payload).unwrap();
        assert_eq!(decoded.model, "llama3.2:3b");
        assert_eq!(decoded.response, "Hello world!");
        assert!(decoded.done);
        assert_eq!(decoded.done_reason.as_deref(), Some("stop"));
        assert_eq!(decoded.eval_count, Some(42));
    }

    /// Optional telemetry fields default to None when Ollama omits them.
    #[test]
    fn test_generate_response_missing_optional_fields() {
        let payload = r#"{
            "model": "llama3.2:3b",
            "response": "Hello",
            "done": true
        }"#;
        let decoded: GenerateResponse = serde_json::from_str(payload).unwrap();
        assert!(decoded.done_reason.is_none());
        assert!(decoded.total_duration.is_none());
        assert!(decoded.eval_count.is_none());
    }

    #[test]
    fn test_generate_stream_chunk_deserialization() {
        let decoded: GenerateStreamChunk =
            serde_json::from_str(r#"{"model":"llama3.2:3b","response":"Hello","done":false}"#)
                .unwrap();
        assert_eq!(decoded.response, "Hello");
        assert!(!decoded.done);
    }

    /// The terminating chunk carries done:true plus the stop metadata.
    #[test]
    fn test_generate_stream_chunk_final() {
        let payload = r#"{
            "model": "llama3.2:3b",
            "response": "",
            "done": true,
            "done_reason": "stop",
            "eval_count": 15
        }"#;
        let last: GenerateStreamChunk = serde_json::from_str(payload).unwrap();
        assert!(last.done);
        assert_eq!(last.done_reason.as_deref(), Some("stop"));
        assert_eq!(last.eval_count, Some(15));
    }

    #[test]
    fn test_chat_request_serialization() {
        let system = ChatMessage {
            role: ChatRole::System,
            content: "You are helpful.".to_string(),
        };
        let user = ChatMessage {
            role: ChatRole::User,
            content: "Hello".to_string(),
        };
        let request = ChatRequest {
            model: "llama3.2:3b".to_string(),
            messages: vec![system, user],
            stream: false,
            options: None,
        };
        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(value["messages"].as_array().unwrap().len(), 2);
        assert_eq!(value["messages"][0]["role"], "system");
        assert_eq!(value["messages"][1]["role"], "user");
    }

    /// rename_all = "lowercase" turns variant names into lowercase strings.
    #[test]
    fn test_chat_role_serialization() {
        assert_eq!(serde_json::to_value(ChatRole::Assistant).unwrap(), "assistant");
    }

    #[test]
    fn test_chat_response_deserialization() {
        let payload = r#"{
            "model": "llama3.2:3b",
            "message": {"role": "assistant", "content": "Hi there!"},
            "done": true,
            "done_reason": "stop",
            "eval_count": 5
        }"#;
        let decoded: ChatResponse = serde_json::from_str(payload).unwrap();
        assert_eq!(decoded.message.content, "Hi there!");
        assert!(decoded.done);
    }

    #[test]
    fn test_embed_request_serialization() {
        let request = EmbedRequest {
            model: "nomic-embed-text".to_string(),
            input: vec!["hello".to_string(), "world".to_string()],
        };
        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(value["model"], "nomic-embed-text");
        assert_eq!(value["input"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_embed_response_deserialization() {
        let payload = r#"{
            "model": "nomic-embed-text",
            "embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        }"#;
        let decoded: EmbedResponse = serde_json::from_str(payload).unwrap();
        assert_eq!(decoded.embeddings.len(), 2);
        assert_eq!(decoded.embeddings[0].len(), 3);
    }

    /// Digest is optional per entry; absent digests deserialize to None.
    #[test]
    fn test_list_models_response_deserialization() {
        let payload = r#"{
            "models": [
                {"name": "llama3.2:3b", "model": "llama3.2:3b", "size": 1234567, "digest": "abc123"},
                {"name": "nomic-embed-text", "model": "nomic-embed-text", "size": 987654}
            ]
        }"#;
        let decoded: ListModelsResponse = serde_json::from_str(payload).unwrap();
        assert_eq!(decoded.models.len(), 2);
        assert_eq!(decoded.models[0].name, "llama3.2:3b");
        assert_eq!(decoded.models[0].digest.as_deref(), Some("abc123"));
        assert!(decoded.models[1].digest.is_none());
    }
}
|
||||
@@ -6,15 +6,19 @@ use llm_multiverse_proto::llm_multiverse::v1::{
|
||||
use tonic::{Request, Response, Status};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::ollama::OllamaClient;
|
||||
|
||||
/// Implementation of the ModelGatewayService gRPC trait.
pub struct ModelGatewayServiceImpl {
    /// Service configuration (Ollama URL, model aliases, etc.).
    config: Config,
    /// HTTP client for the Ollama backend. Not referenced by any handler yet
    /// (the inference/embedding RPCs are still unimplemented per the tests
    /// below), hence the dead_code allow.
    #[allow(dead_code)]
    ollama: OllamaClient,
}
|
||||
|
||||
impl ModelGatewayServiceImpl {
|
||||
pub fn new(config: Config) -> Self {
|
||||
Self { config }
|
||||
pub fn new(config: Config) -> Result<Self, anyhow::Error> {
|
||||
let ollama = OllamaClient::new(&config)?;
|
||||
Ok(Self { config, ollama })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,7 +94,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_is_model_ready_all_models() {
|
||||
let svc = ModelGatewayServiceImpl::new(test_config());
|
||||
let svc = ModelGatewayServiceImpl::new(test_config()).unwrap();
|
||||
let req = Request::new(IsModelReadyRequest { model_name: None });
|
||||
|
||||
let resp = svc.is_model_ready(req).await.unwrap().into_inner();
|
||||
@@ -105,7 +109,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_is_model_ready_specific_model_found() {
|
||||
let svc = ModelGatewayServiceImpl::new(test_config());
|
||||
let svc = ModelGatewayServiceImpl::new(test_config()).unwrap();
|
||||
let req = Request::new(IsModelReadyRequest {
|
||||
model_name: Some("llama3.2:3b".to_string()),
|
||||
});
|
||||
@@ -118,7 +122,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_is_model_ready_specific_model_not_found() {
|
||||
let svc = ModelGatewayServiceImpl::new(test_config());
|
||||
let svc = ModelGatewayServiceImpl::new(test_config()).unwrap();
|
||||
let req = Request::new(IsModelReadyRequest {
|
||||
model_name: Some("nonexistent:latest".to_string()),
|
||||
});
|
||||
@@ -137,7 +141,7 @@ mod tests {
|
||||
.aliases
|
||||
.insert("code".into(), "codellama:7b".into());
|
||||
|
||||
let svc = ModelGatewayServiceImpl::new(config);
|
||||
let svc = ModelGatewayServiceImpl::new(config).unwrap();
|
||||
let req = Request::new(IsModelReadyRequest {
|
||||
model_name: Some("codellama:7b".to_string()),
|
||||
});
|
||||
@@ -148,7 +152,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stream_inference_unimplemented() {
|
||||
let svc = ModelGatewayServiceImpl::new(test_config());
|
||||
let svc = ModelGatewayServiceImpl::new(test_config()).unwrap();
|
||||
let req = Request::new(StreamInferenceRequest { params: None });
|
||||
|
||||
let result = svc.stream_inference(req).await;
|
||||
@@ -158,7 +162,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_inference_unimplemented() {
|
||||
let svc = ModelGatewayServiceImpl::new(test_config());
|
||||
let svc = ModelGatewayServiceImpl::new(test_config()).unwrap();
|
||||
let req = Request::new(InferenceRequest { params: None });
|
||||
|
||||
let result = svc.inference(req).await;
|
||||
@@ -168,7 +172,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_embedding_unimplemented() {
|
||||
let svc = ModelGatewayServiceImpl::new(test_config());
|
||||
let svc = ModelGatewayServiceImpl::new(test_config()).unwrap();
|
||||
let req = Request::new(GenerateEmbeddingRequest {
|
||||
context: None,
|
||||
text: "test".into(),
|
||||
|
||||
Reference in New Issue
Block a user