feat: implement Ollama HTTP client for Model Gateway (issue #39)

Add async HTTP client wrapping the Ollama REST API with:
- OllamaClient with generate, generate_stream, chat, embed, list_models, is_healthy
- NDJSON streaming parser for /api/generate streaming responses
- Serde types for all Ollama API endpoints
- OllamaError enum with Http, Api, Deserialization, StreamIncomplete variants
- OllamaClientConfig for timeout and connection pool settings
- Integration into ModelGatewayServiceImpl (constructor now returns Result)
- 48 tests (types serde, wiremock HTTP mocks, error handling, config)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Pi Agent
2026-03-10 13:56:53 +01:00
parent 0c55de22f2
commit a38ea1db51
12 changed files with 1966 additions and 10 deletions

290
Cargo.lock generated
View File

@@ -230,6 +230,16 @@ dependencies = [
"regex-syntax",
]
[[package]]
name = "assert-json-diff"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "async-trait"
version = "0.1.89"
@@ -490,6 +500,26 @@ dependencies = [
"tiny-keccak",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.7"
@@ -530,6 +560,24 @@ dependencies = [
"typenum",
]
[[package]]
name = "deadpool"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b"
dependencies = [
"deadpool-runtime",
"lazy_static",
"num_cpus",
"tokio",
]
[[package]]
name = "deadpool-runtime"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b"
[[package]]
name = "derive_arbitrary"
version = "1.4.2"
@@ -585,6 +633,15 @@ version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]]
name = "encoding_rs"
version = "0.8.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3"
dependencies = [
"cfg-if",
]
[[package]]
name = "equivalent"
version = "1.0.2"
@@ -665,6 +722,21 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared",
]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "form_urlencoded"
version = "1.2.2"
@@ -888,6 +960,12 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
[[package]]
name = "http"
version = "1.4.0"
@@ -986,6 +1064,22 @@ dependencies = [
"tower-service",
]
[[package]]
name = "hyper-tls"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
"hyper",
"hyper-util",
"native-tls",
"tokio",
"tokio-native-tls",
"tower-service",
]
[[package]]
name = "hyper-util"
version = "0.1.20"
@@ -1004,9 +1098,11 @@ dependencies = [
"percent-encoding",
"pin-project-lite",
"socket2 0.6.3",
"system-configuration",
"tokio",
"tower-service",
"tracing",
"windows-registry",
]
[[package]]
@@ -1455,10 +1551,14 @@ name = "model-gateway"
version = "0.1.0"
dependencies = [
"anyhow",
"bytes",
"futures",
"llm-multiverse-proto",
"prost",
"prost-types",
"reqwest",
"serde",
"serde_json",
"tempfile",
"thiserror",
"tokio",
@@ -1467,6 +1567,7 @@ dependencies = [
"tonic",
"tracing",
"tracing-subscriber",
"wiremock",
]
[[package]]
@@ -1475,6 +1576,23 @@ version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084"
[[package]]
name = "native-tls"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "nu-ansi-term"
version = "0.50.3"
@@ -1558,12 +1676,66 @@ dependencies = [
"libm",
]
[[package]]
name = "num_cpus"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b"
dependencies = [
"hermit-abi",
"libc",
]
[[package]]
name = "once_cell"
version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "openssl"
version = "0.10.75"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328"
dependencies = [
"bitflags",
"cfg-if",
"foreign-types",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "openssl-probe"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
[[package]]
name = "openssl-sys"
version = "0.9.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "parking_lot"
version = "0.12.5"
@@ -1970,17 +2142,22 @@ checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147"
dependencies = [
"base64",
"bytes",
"encoding_rs",
"futures-channel",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-tls",
"hyper-util",
"js-sys",
"log",
"mime",
"native-tls",
"percent-encoding",
"pin-project-lite",
"quinn",
@@ -1991,13 +2168,16 @@ dependencies = [
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-native-tls",
"tokio-rustls",
"tokio-util",
"tower",
"tower-http",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"webpki-roots",
]
@@ -2127,6 +2307,15 @@ version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f"
[[package]]
name = "schannel"
version = "0.1.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939"
dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
@@ -2161,6 +2350,29 @@ dependencies = [
"tracing-subscriber",
]
[[package]]
name = "security-framework"
version = "3.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d"
dependencies = [
"bitflags",
"core-foundation 0.10.1",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "semver"
version = "1.0.27"
@@ -2405,6 +2617,27 @@ dependencies = [
"syn 2.0.117",
]
[[package]]
name = "system-configuration"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b"
dependencies = [
"bitflags",
"core-foundation 0.9.4",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "tap"
version = "1.0.1"
@@ -2526,6 +2759,16 @@ dependencies = [
"syn 2.0.117",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.26.4"
@@ -2980,6 +3223,19 @@ dependencies = [
"wasmparser",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "wasmparser"
version = "0.244.0"
@@ -3062,6 +3318,17 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
[[package]]
name = "windows-registry"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720"
dependencies = [
"windows-link",
"windows-result",
"windows-strings",
]
[[package]]
name = "windows-result"
version = "0.4.1"
@@ -3171,6 +3438,29 @@ dependencies = [
"memchr",
]
[[package]]
name = "wiremock"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031"
dependencies = [
"assert-json-diff",
"base64",
"deadpool",
"futures",
"http",
"http-body-util",
"hyper",
"hyper-util",
"log",
"once_cell",
"regex",
"serde",
"serde_json",
"tokio",
"url",
]
[[package]]
name = "wit-bindgen"
version = "0.51.0"

View File

@@ -42,6 +42,7 @@
| #36 | Implement GetCorrelated gRPC endpoint | Phase 4 | `COMPLETED` | Rust | [issue-036.md](issue-036.md) |
| #37 | Integration tests for Memory Service | Phase 4 | `COMPLETED` | Rust | [issue-037.md](issue-037.md) |
| #38 | Scaffold Model Gateway Rust project | Phase 5 | `COMPLETED` | Rust | [issue-038.md](issue-038.md) |
| #39 | Implement Ollama HTTP client | Phase 5 | `IMPLEMENTING` | Rust | [issue-039.md](issue-039.md) |
## Status Legend
@@ -83,6 +84,8 @@
### Model Gateway
- [issue-012.md](issue-012.md) — model_gateway.proto (ModelGatewayService)
- [issue-038.md](issue-038.md) — Scaffold Model Gateway Rust project
- [issue-039.md](issue-039.md) — Ollama HTTP client (reqwest, streaming, embeddings)
### Search Service
- [issue-013.md](issue-013.md) — search.proto (SearchService)

View File

@@ -0,0 +1,631 @@
# Implementation Plan — Issue #39: Implement Ollama HTTP client
## Metadata
| Field | Value |
|---|---|
| Issue | [#39](https://git.shahondin1624.de/llm-multiverse/llm-multiverse/issues/39) |
| Title | Implement Ollama HTTP client |
| Milestone | Phase 5: Model Gateway |
| Labels | |
| Status | `IMPLEMENTING` |
| Language | Rust |
| Related Plans | [issue-012.md](issue-012.md), [issue-038.md](issue-038.md) |
| Blocked by | #38 (completed) |
## Acceptance Criteria
- [ ] Async HTTP client using reqwest
- [ ] Support for /api/generate (streaming and non-streaming)
- [ ] Support for /api/chat with message history
- [ ] Support for /api/embed (embeddings)
- [ ] Connection pooling and timeout configuration
- [ ] Error handling for Ollama-specific error responses
## Architecture Analysis
### Service Context
This issue belongs to the **Model Gateway** service (`services/model-gateway/`). The Ollama HTTP client is the core backend layer that the gRPC service handlers (`service.rs`) will call to fulfil `Inference`, `StreamInference`, `GenerateEmbedding`, and `IsModelReady` RPCs.
The client wraps the Ollama REST API (default `http://localhost:11434`) and exposes typed Rust methods that the gRPC handlers can call directly. The gRPC handlers (issue #40+) will translate proto request/response types to/from the Ollama client types defined here.
**gRPC endpoints affected (consumers of this client):**
- `Inference` — will call `OllamaClient::generate()` (non-streaming)
- `StreamInference` — will call `OllamaClient::generate_stream()` (streaming)
- `GenerateEmbedding` — will call `OllamaClient::embed()`
- `IsModelReady` — will call `OllamaClient::list_models()` to check actual Ollama availability
**Proto messages involved:**
- `InferenceParams` — carries prompt, model routing hints (TaskComplexity), temperature, top_p, max_tokens, stop_sequences
- `InferenceResponse` — text, finish_reason, tokens_used
- `StreamInferenceResponse` — token, finish_reason
- `GenerateEmbeddingRequest` — text, model
- `GenerateEmbeddingResponse` — embedding vector, dimensions
### Existing Patterns
- **Config:** `services/model-gateway/src/config.rs` already defines `Config` with `ollama_url: String` (default `http://localhost:11434`) and `ModelRoutingConfig` for model name resolution.
- **Service struct:** `services/model-gateway/src/service.rs` defines `ModelGatewayServiceImpl` holding `Config`. The `OllamaClient` will be added here as a field.
- **Error types:** Other services use `thiserror` for module-level error enums (e.g., `DbError`, `EmbeddingError`, `ProvenanceError`). The model-gateway `Cargo.toml` already includes `thiserror = "2"`.
- **Async runtime:** `tokio` with `features = ["full"]` is already a dependency. `tokio-stream = "0.1"` is also present.
- **Serde:** `serde = { version = "1", features = ["derive"] }` is already a dependency for config deserialization.
### Dependencies
- **reqwest** (new) — HTTP client with connection pooling, async support, JSON serialization, and streaming response bodies. Features needed: `json`, `stream`.
- **futures** (new) — For `Stream` trait and stream combinators (`futures::Stream`, `futures::StreamExt`). Needed to expose streaming generate responses as a `Stream` type.
- **serde_json** (new) — For parsing newline-delimited JSON (NDJSON) from Ollama streaming responses. While `reqwest` can deserialize full JSON responses, streaming requires manual line-by-line parsing.
- **No proto changes required** — the Ollama client is an internal HTTP layer; the proto definitions are already complete from issue #12.
## Implementation Steps
### 1. Types & Configuration
**Add Ollama-specific configuration to `services/model-gateway/src/config.rs`:**
```rust
/// Configuration for the Ollama HTTP client.
#[derive(Debug, Clone, Deserialize)]
pub struct OllamaClientConfig {
/// Request timeout in seconds (default: 300 — generous for large model inference).
#[serde(default = "default_request_timeout_secs")]
pub request_timeout_secs: u64,
/// Connection timeout in seconds (default: 10).
#[serde(default = "default_connect_timeout_secs")]
pub connect_timeout_secs: u64,
/// Maximum idle connections in the pool (default: 10).
#[serde(default = "default_pool_max_idle")]
pub pool_max_idle: usize,
/// Idle connection timeout in seconds (default: 60).
#[serde(default = "default_pool_idle_timeout_secs")]
pub pool_idle_timeout_secs: u64,
}
```
Add `#[serde(default)] pub client: OllamaClientConfig` field to the existing `Config` struct.
**Define Ollama API request/response types in `services/model-gateway/src/ollama/types.rs`:**
These are serde structs matching the Ollama REST API JSON schema.
```rust
use serde::{Deserialize, Serialize};
// --- /api/generate ---
#[derive(Debug, Serialize)]
pub struct GenerateRequest {
pub model: String,
pub prompt: String,
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub options: Option<GenerateOptions>,
}
#[derive(Debug, Serialize)]
pub struct GenerateOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub num_predict: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
}
/// Full response from /api/generate with stream:false.
#[derive(Debug, Deserialize)]
pub struct GenerateResponse {
pub model: String,
pub response: String,
pub done: bool,
#[serde(default)]
pub done_reason: Option<String>,
#[serde(default)]
pub total_duration: Option<u64>,
#[serde(default)]
pub eval_count: Option<u32>,
#[serde(default)]
pub prompt_eval_count: Option<u32>,
}
/// Single chunk from /api/generate with stream:true (NDJSON).
#[derive(Debug, Deserialize)]
pub struct GenerateStreamChunk {
pub model: String,
pub response: String,
pub done: bool,
#[serde(default)]
pub done_reason: Option<String>,
#[serde(default)]
pub eval_count: Option<u32>,
#[serde(default)]
pub prompt_eval_count: Option<u32>,
}
// --- /api/chat ---
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
System,
User,
Assistant,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: ChatRole,
pub content: String,
}
#[derive(Debug, Serialize)]
pub struct ChatRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub options: Option<GenerateOptions>,
}
#[derive(Debug, Deserialize)]
pub struct ChatResponse {
pub model: String,
pub message: ChatMessage,
pub done: bool,
#[serde(default)]
pub done_reason: Option<String>,
#[serde(default)]
pub total_duration: Option<u64>,
#[serde(default)]
pub eval_count: Option<u32>,
#[serde(default)]
pub prompt_eval_count: Option<u32>,
}
// --- /api/embed ---
#[derive(Debug, Serialize)]
pub struct EmbedRequest {
pub model: String,
pub input: Vec<String>,
}
#[derive(Debug, Deserialize)]
pub struct EmbedResponse {
pub model: String,
pub embeddings: Vec<Vec<f32>>,
}
// --- /api/tags (list models) ---
#[derive(Debug, Deserialize)]
pub struct ListModelsResponse {
pub models: Vec<ModelInfo>,
}
#[derive(Debug, Deserialize)]
pub struct ModelInfo {
pub name: String,
pub model: String,
#[serde(default)]
pub size: u64,
#[serde(default)]
pub digest: Option<String>,
}
// --- /api/show (model details) ---
#[derive(Debug, Serialize)]
pub struct ShowModelRequest {
pub model: String,
}
#[derive(Debug, Deserialize)]
pub struct ShowModelResponse {
pub modelfile: Option<String>,
pub parameters: Option<String>,
pub template: Option<String>,
}
```
### 2. Core Logic
**Create `services/model-gateway/src/ollama/error.rs` — Error types:**
```rust
use thiserror::Error;
#[derive(Debug, Error)]
pub enum OllamaError {
/// HTTP-level error (connection refused, timeout, DNS, TLS, etc.).
#[error("HTTP error: {0}")]
Http(#[from] reqwest::Error),
/// Ollama returned a non-2xx status code.
#[error("Ollama API error (status {status}): {message}")]
Api {
status: u16,
message: String,
},
/// Failed to deserialize Ollama JSON response.
#[error("deserialization error: {0}")]
Deserialization(String),
/// Stream terminated unexpectedly without a done:true chunk.
#[error("stream ended unexpectedly")]
StreamIncomplete,
}
```
**Create `services/model-gateway/src/ollama/client.rs` — OllamaClient:**
```rust
use std::time::Duration;
use futures::Stream;
use reqwest::Client;
use crate::config::{Config, OllamaClientConfig};
use super::error::OllamaError;
use super::types::*;
/// Async HTTP client for the Ollama REST API.
///
/// Wraps `reqwest::Client` with connection pooling, timeouts, and
/// typed methods for each Ollama endpoint.
pub struct OllamaClient {
client: Client,
base_url: String,
}
impl OllamaClient {
/// Create a new client from the service configuration.
///
/// Configures connection pooling, timeouts, and the base URL
/// from `Config.ollama_url` and `Config.client`.
pub fn new(config: &Config) -> Result<Self, OllamaError> {
let client_config = &config.client;
let client = Client::builder()
.timeout(Duration::from_secs(client_config.request_timeout_secs))
.connect_timeout(Duration::from_secs(client_config.connect_timeout_secs))
.pool_max_idle_per_host(client_config.pool_max_idle)
.pool_idle_timeout(Duration::from_secs(client_config.pool_idle_timeout_secs))
.build()?;
let base_url = config.ollama_url.trim_end_matches('/').to_string();
Ok(Self { client, base_url })
}
/// POST /api/generate (non-streaming).
///
/// Sends a prompt to the specified model and returns the complete response.
pub async fn generate(
&self,
model: &str,
prompt: &str,
options: Option<GenerateOptions>,
) -> Result<GenerateResponse, OllamaError> {
let request = GenerateRequest {
model: model.to_string(),
prompt: prompt.to_string(),
stream: false,
options,
};
let resp = self.client
.post(format!("{}/api/generate", self.base_url))
.json(&request)
.send()
.await?;
self.handle_error_response(resp)
.await?
.json::<GenerateResponse>()
.await
.map_err(|e| OllamaError::Deserialization(e.to_string()))
}
/// POST /api/generate (streaming).
///
/// Returns a `Stream` of `GenerateStreamChunk` items. Each chunk
/// contains a partial token. The final chunk has `done: true`.
///
/// Ollama streams NDJSON (one JSON object per line). This method
/// reads the response body as a byte stream, splits on newlines,
/// and deserializes each line.
pub async fn generate_stream(
&self,
model: &str,
prompt: &str,
options: Option<GenerateOptions>,
) -> Result<
impl Stream<Item = Result<GenerateStreamChunk, OllamaError>>,
OllamaError,
> {
let request = GenerateRequest {
model: model.to_string(),
prompt: prompt.to_string(),
stream: true,
options,
};
let resp = self.client
.post(format!("{}/api/generate", self.base_url))
.json(&request)
.send()
.await?;
let resp = self.handle_error_response(resp).await?;
Ok(Self::ndjson_stream::<GenerateStreamChunk>(resp))
}
/// POST /api/chat (non-streaming).
///
/// Sends a chat conversation (message history) to the model.
pub async fn chat(
&self,
model: &str,
messages: Vec<ChatMessage>,
options: Option<GenerateOptions>,
) -> Result<ChatResponse, OllamaError> {
let request = ChatRequest {
model: model.to_string(),
messages,
stream: false,
options,
};
let resp = self.client
.post(format!("{}/api/chat", self.base_url))
.json(&request)
.send()
.await?;
self.handle_error_response(resp)
.await?
.json::<ChatResponse>()
.await
.map_err(|e| OllamaError::Deserialization(e.to_string()))
}
/// POST /api/embed.
///
/// Generates embedding vectors for the given input texts.
/// Returns one embedding vector per input string.
pub async fn embed(
&self,
model: &str,
input: Vec<String>,
) -> Result<EmbedResponse, OllamaError> {
let request = EmbedRequest {
model: model.to_string(),
input,
};
let resp = self.client
.post(format!("{}/api/embed", self.base_url))
.json(&request)
.send()
.await?;
self.handle_error_response(resp)
.await?
.json::<EmbedResponse>()
.await
.map_err(|e| OllamaError::Deserialization(e.to_string()))
}
/// GET /api/tags.
///
/// Lists all models available on the Ollama instance.
pub async fn list_models(&self) -> Result<ListModelsResponse, OllamaError> {
let resp = self.client
.get(format!("{}/api/tags", self.base_url))
.send()
.await?;
self.handle_error_response(resp)
.await?
.json::<ListModelsResponse>()
.await
.map_err(|e| OllamaError::Deserialization(e.to_string()))
}
/// Check if Ollama is reachable by hitting GET /api/tags.
/// Returns true if the request succeeds, false otherwise.
pub async fn is_healthy(&self) -> bool {
self.list_models().await.is_ok()
}
/// Parse NDJSON streaming response into a Stream of typed chunks.
///
/// Ollama streams responses as newline-delimited JSON. Each line
/// is a complete JSON object. This method uses `bytes_stream()`
/// from reqwest and buffers bytes until a newline is found,
/// then deserializes each complete line.
fn ndjson_stream<T: serde::de::DeserializeOwned>(
resp: reqwest::Response,
) -> impl Stream<Item = Result<T, OllamaError>> {
use futures::StreamExt;
let byte_stream = resp.bytes_stream();
let buffer = Vec::new();
futures::stream::unfold(
(byte_stream, buffer),
|(mut stream, mut buf)| async move {
// Implementation: accumulate bytes, split on \n,
// deserialize each complete line as T.
// Return None when stream ends.
// ...
},
)
}
/// Check response status and extract error message for non-2xx responses.
async fn handle_error_response(
&self,
resp: reqwest::Response,
) -> Result<reqwest::Response, OllamaError> {
if resp.status().is_success() {
return Ok(resp);
}
let status = resp.status().as_u16();
let message = resp
.text()
.await
.unwrap_or_else(|_| "unknown error".to_string());
Err(OllamaError::Api { status, message })
}
}
```
**NDJSON stream implementation detail:**
The `ndjson_stream` method will use `reqwest::Response::bytes_stream()` (requires the `stream` feature) and `futures::stream::unfold` to:
1. Accumulate bytes from the HTTP response body into a buffer.
2. On each newline boundary, extract the complete line.
3. Deserialize the line as `T` using `serde_json::from_slice`.
4. Yield `Ok(T)` or `Err(OllamaError::Deserialization(...))`.
5. Return `None` when the byte stream is exhausted.
This approach correctly handles partial JSON objects that span multiple TCP chunks.
### 3. gRPC Handler Wiring
This issue does **not** implement the gRPC handler wiring — that is deferred to subsequent issues. However, the `OllamaClient` must be integrated into `ModelGatewayServiceImpl` so that future handler implementations can use it.
**Update `services/model-gateway/src/service.rs`:**
Add `OllamaClient` as a field on `ModelGatewayServiceImpl`:
```rust
use crate::ollama::OllamaClient;
pub struct ModelGatewayServiceImpl {
config: Config,
ollama: OllamaClient,
}
impl ModelGatewayServiceImpl {
pub fn new(config: Config) -> anyhow::Result<Self> {
let ollama = OllamaClient::new(&config)?;
Ok(Self { config, ollama })
}
}
```
Note: The constructor changes from infallible to `Result` since `reqwest::Client::builder().build()` can fail. Update `main.rs` accordingly to use `?`.
**Update `services/model-gateway/src/main.rs`:**
Change `ModelGatewayServiceImpl::new(config)` to `ModelGatewayServiceImpl::new(config)?`.
### 4. Service Integration
No cross-service integration is needed for this issue. The `OllamaClient` is a standalone HTTP client that talks to the local Ollama instance. Integration with gRPC handlers will happen in follow-up issues.
### 5. Tests
**Unit tests for serde types in `services/model-gateway/src/ollama/types.rs`:**
| Test Case | Description |
|---|---|
| `test_generate_request_serialization` | `GenerateRequest` serializes to expected JSON with `stream: false` |
| `test_generate_request_serialization_with_options` | Options fields are included when `Some`, omitted when `None` |
| `test_generate_response_deserialization` | Deserialize a complete Ollama generate response JSON |
| `test_generate_response_missing_optional_fields` | Optional fields default to `None` when absent |
| `test_generate_stream_chunk_deserialization` | Deserialize a streaming chunk (partial token, `done: false`) |
| `test_generate_stream_chunk_final` | Deserialize final chunk with `done: true` and `done_reason` |
| `test_chat_request_serialization` | `ChatRequest` with multiple messages serializes correctly |
| `test_chat_role_serialization` | `ChatRole` variants serialize as lowercase strings |
| `test_chat_response_deserialization` | Deserialize a complete chat response |
| `test_embed_request_serialization` | `EmbedRequest` with multiple inputs serializes correctly |
| `test_embed_response_deserialization` | Deserialize embedding response with vector data |
| `test_list_models_response_deserialization` | Deserialize model listing with multiple models |
| `test_model_info_optional_fields` | `ModelInfo` handles missing `digest` gracefully |
**Unit tests for error handling in `services/model-gateway/src/ollama/error.rs`:**
| Test Case | Description |
|---|---|
| `test_error_display_http` | `OllamaError::Http` formats with reqwest message |
| `test_error_display_api` | `OllamaError::Api` includes status code and message |
| `test_error_display_deserialization` | `OllamaError::Deserialization` includes detail |
**Integration-style tests for `OllamaClient` in `services/model-gateway/src/ollama/client.rs`:**
Use a mock HTTP server (either `mockito` or `wiremock`) to simulate Ollama API responses:
| Test Case | Description |
|---|---|
| `test_generate_success` | Mock `/api/generate` returns valid JSON, verify parsed response |
| `test_generate_with_options` | Verify temperature, top_p, num_predict, stop are sent in request body |
| `test_generate_stream_success` | Mock returns NDJSON with 3 chunks + final, verify all chunks yielded |
| `test_generate_stream_empty_response` | Mock returns single `done: true` chunk |
| `test_chat_success` | Mock `/api/chat` returns valid response, verify message parsing |
| `test_chat_with_history` | Send multi-message conversation, verify all messages in request body |
| `test_embed_success` | Mock `/api/embed` returns embedding vectors, verify dimensions |
| `test_embed_multiple_inputs` | Send multiple texts, verify multiple embeddings returned |
| `test_list_models_success` | Mock `/api/tags` returns model list |
| `test_list_models_empty` | Mock returns empty model list |
| `test_is_healthy_success` | Mock `/api/tags` returns 200, `is_healthy()` returns true |
| `test_is_healthy_failure` | Mock returns 500, `is_healthy()` returns false |
| `test_api_error_404` | Mock returns 404 with error message, verify `OllamaError::Api` |
| `test_api_error_500` | Mock returns 500 with error body, verify error extraction |
| `test_connection_timeout` | Client configured with very short timeout, verify `OllamaError::Http` |
| `test_base_url_trailing_slash` | Config URL with trailing slash is normalized |
**Mocking strategy:**
Use `wiremock` crate as a dev-dependency. It provides a `MockServer` that binds to a random port, allowing parallel test execution without port conflicts. Each test creates its own `MockServer`, configures expected requests/responses, then creates an `OllamaClient` pointed at the mock server URL.
For streaming tests, the mock server returns a response body containing multiple NDJSON lines separated by `\n`.
**Configuration tests in `services/model-gateway/src/config.rs`:**
| Test Case | Description |
|---|---|
| `test_client_config_defaults` | `OllamaClientConfig::default()` returns expected timeout/pool values |
| `test_client_config_from_toml` | Custom client config loads from TOML |
| `test_config_with_client_section` | Full `Config` with `[client]` section parses correctly |
## Files to Create/Modify
| File | Action | Purpose |
|---|---|---|
| `services/model-gateway/Cargo.toml` | Modify | Add `reqwest` (with `json`, `stream` features), `futures`, `serde_json` dependencies; add `wiremock` dev-dependency |
| `services/model-gateway/src/config.rs` | Modify | Add `OllamaClientConfig` struct with timeout/pool settings; add `client` field to `Config` |
| `services/model-gateway/src/ollama/mod.rs` | Create | Module declaration, re-exports of `OllamaClient`, `OllamaError`, and types |
| `services/model-gateway/src/ollama/types.rs` | Create | Serde request/response structs for all Ollama API endpoints |
| `services/model-gateway/src/ollama/error.rs` | Create | `OllamaError` enum with `Http`, `Api`, `Deserialization`, `StreamIncomplete` variants |
| `services/model-gateway/src/ollama/client.rs` | Create | `OllamaClient` struct with `generate`, `generate_stream`, `chat`, `embed`, `list_models`, `is_healthy` methods and NDJSON stream parser |
| `services/model-gateway/src/lib.rs` | Modify | Add `pub mod ollama;` |
| `services/model-gateway/src/service.rs` | Modify | Add `OllamaClient` field to `ModelGatewayServiceImpl`; change constructor to return `Result` |
| `services/model-gateway/src/main.rs` | Modify | Update `ModelGatewayServiceImpl::new(config)` call to handle `Result` with `?` |
## Risks and Edge Cases
- **Streaming NDJSON parsing:** Ollama sends newline-delimited JSON. TCP chunks may not align with JSON object boundaries — a single chunk could contain a partial JSON line or multiple lines. The buffer-based `ndjson_stream` implementation must handle both cases. Mitigation: accumulate bytes until `\n` is found, only parse complete lines.
- **Large model response times:** Inference on large models (14B+) can take minutes. The default request timeout of 300 seconds should be sufficient, but this is configurable. Streaming mitigates perceived latency by yielding tokens incrementally.
- **Ollama API version compatibility:** The `/api/embed` endpoint (with `input` array) was introduced in Ollama 0.1.44+. Older Ollama versions use `/api/embeddings` with a different request shape. Mitigation: target the newer API. Document the minimum Ollama version requirement.
- **Connection pool exhaustion:** If many concurrent gRPC requests hit the gateway simultaneously, the reqwest connection pool could be exhausted. Mitigation: `pool_max_idle` is configurable; the default of 10 is reasonable for a single-node setup. Consider adding a semaphore for concurrency limiting in a future issue if needed.
- **`wiremock` test isolation:** Each test creates its own `MockServer` on a random port, so tests can run in parallel safely. However, `wiremock` increases dev-dependency compile time.
- **Constructor change breaks existing tests:** Changing `ModelGatewayServiceImpl::new()` from infallible to `Result` will break existing tests in `service.rs`. Mitigation: update the test helper `test_config()` to also construct the `OllamaClient`, or use a test-only constructor that accepts a pre-built client. Alternatively, keep a separate `new_with_client()` constructor for testability and dependency injection.
- **reqwest TLS:** The default reqwest build pulls in `rustls` or `native-tls`. Since Ollama runs locally over plain HTTP, TLS is not needed. Consider using `default-features = false` with just the required features to minimize compile time and binary size. However, if a user runs Ollama behind a TLS reverse proxy, TLS support is needed. Mitigation: use default features (includes TLS) for now; optimize later if compile time is a concern.
## Deviation Log
_(Filled during implementation if deviations from plan occur)_
| Deviation | Reason |
|---|---|

View File

@@ -18,6 +18,11 @@ toml = "0.8"
anyhow = "1"
thiserror = "2"
tokio-stream = "0.1"
reqwest = { version = "0.12", features = ["json", "stream"] }
futures = "0.3"
serde_json = "1"
bytes = "1"
[dev-dependencies]
tempfile = "3"
wiremock = "0.6"

View File

@@ -71,6 +71,53 @@ impl ModelRoutingConfig {
}
}
/// Configuration for the Ollama HTTP client.
///
/// Deserialized from the optional `[client]` table of the gateway TOML
/// config. Every field carries a serde default, so the whole table — or
/// any individual key — may be omitted.
#[derive(Debug, Clone, Deserialize)]
pub struct OllamaClientConfig {
    /// Request timeout in seconds (default: 300).
    /// Deliberately generous: inference on large models can take minutes.
    #[serde(default = "default_request_timeout_secs")]
    pub request_timeout_secs: u64,
    /// Connection timeout in seconds (default: 10).
    #[serde(default = "default_connect_timeout_secs")]
    pub connect_timeout_secs: u64,
    /// Maximum idle connections in the pool (default: 10).
    #[serde(default = "default_pool_max_idle")]
    pub pool_max_idle: usize,
    /// Idle connection timeout in seconds (default: 60).
    #[serde(default = "default_pool_idle_timeout_secs")]
    pub pool_idle_timeout_secs: u64,
}
// Serde's `#[serde(default = "...")]` attribute takes the path of a free
// function, hence these helpers. Keep the values in sync with the defaults
// documented on `OllamaClientConfig` and returned by its `Default` impl.
fn default_request_timeout_secs() -> u64 {
    300
}
fn default_connect_timeout_secs() -> u64 {
    10
}
fn default_pool_max_idle() -> usize {
    10
}
fn default_pool_idle_timeout_secs() -> u64 {
    60
}
impl Default for OllamaClientConfig {
    /// Delegate to the same helper functions used by the serde field
    /// defaults, so programmatic construction and deserialization of an
    /// empty `[client]` table always agree.
    fn default() -> Self {
        Self {
            request_timeout_secs: default_request_timeout_secs(),
            connect_timeout_secs: default_connect_timeout_secs(),
            pool_max_idle: default_pool_max_idle(),
            pool_idle_timeout_secs: default_pool_idle_timeout_secs(),
        }
    }
}
/// Top-level configuration for the Model Gateway service.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
@@ -92,6 +139,10 @@ pub struct Config {
/// Model routing configuration.
#[serde(default)]
pub routing: ModelRoutingConfig,
/// Ollama HTTP client configuration.
#[serde(default)]
pub client: OllamaClientConfig,
}
fn default_host() -> String {
@@ -114,6 +165,7 @@ impl Default for Config {
ollama_url: default_ollama_url(),
audit_addr: None,
routing: ModelRoutingConfig::default(),
client: OllamaClientConfig::default(),
}
}
}
@@ -230,6 +282,48 @@ code = "codellama:7b"
assert_eq!(count, 1);
}
/// All four client knobs fall back to their documented defaults.
#[test]
fn test_client_config_defaults() {
    let OllamaClientConfig {
        request_timeout_secs,
        connect_timeout_secs,
        pool_max_idle,
        pool_idle_timeout_secs,
    } = OllamaClientConfig::default();
    assert_eq!(request_timeout_secs, 300);
    assert_eq!(connect_timeout_secs, 10);
    assert_eq!(pool_max_idle, 10);
    assert_eq!(pool_idle_timeout_secs, 60);
}
/// A `[client]` table in the TOML file overrides every default.
#[test]
fn test_client_config_from_toml() {
    let dir = tempfile::tempdir().unwrap();
    let config_path = dir.path().join("gateway.toml");
    let toml_body = r#"
host = "0.0.0.0"
port = 9999
[client]
request_timeout_secs = 600
connect_timeout_secs = 5
pool_max_idle = 20
pool_idle_timeout_secs = 120
"#;
    std::fs::write(&config_path, toml_body).unwrap();
    let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
    let cc = &config.client;
    assert_eq!(cc.request_timeout_secs, 600);
    assert_eq!(cc.connect_timeout_secs, 5);
    assert_eq!(cc.pool_max_idle, 20);
    assert_eq!(cc.pool_idle_timeout_secs, 120);
}
/// `Config::default()` carries the default client settings through.
#[test]
fn test_client_config_uses_defaults_when_omitted() {
    let Config { client, .. } = Config::default();
    assert_eq!(client.request_timeout_secs, 300);
    assert_eq!(client.connect_timeout_secs, 10);
}
#[test]
fn test_routing_from_toml_uses_defaults_when_omitted() {
let dir = tempfile::tempdir().unwrap();

View File

@@ -1,2 +1,3 @@
pub mod config;
pub mod ollama;
pub mod service;

View File

@@ -23,7 +23,7 @@ async fn main() -> anyhow::Result<()> {
);
let addr = config.listen_addr().parse()?;
let service = ModelGatewayServiceImpl::new(config);
let service = ModelGatewayServiceImpl::new(config)?;
tracing::info!(%addr, "Model Gateway listening");

View File

@@ -0,0 +1,570 @@
use std::pin::Pin;
use std::time::Duration;
use bytes::Bytes;
use futures::stream::Stream;
use futures::StreamExt;
use reqwest::Client;
use super::error::OllamaError;
use super::types::*;
use crate::config::Config;
/// Async HTTP client for the Ollama REST API.
///
/// Wraps `reqwest::Client` with connection pooling, timeouts, and
/// typed methods for each Ollama endpoint.
pub struct OllamaClient {
    /// Pooled HTTP client; timeouts and pool limits are fixed at
    /// construction time from `Config::client`.
    client: Client,
    /// Ollama base URL with any trailing `/` stripped, so endpoint
    /// paths can be appended with a plain `format!`.
    base_url: String,
}
impl OllamaClient {
    /// Create a new client from the service configuration.
    ///
    /// Builds a pooled `reqwest::Client` using the timeout and idle-pool
    /// settings from `config.client`, and normalizes `config.ollama_url`.
    ///
    /// # Errors
    /// Returns `OllamaError::Http` if the underlying `reqwest` client
    /// fails to build.
    pub fn new(config: &Config) -> Result<Self, OllamaError> {
        let cc = &config.client;
        let client = Client::builder()
            .timeout(Duration::from_secs(cc.request_timeout_secs))
            .connect_timeout(Duration::from_secs(cc.connect_timeout_secs))
            .pool_max_idle_per_host(cc.pool_max_idle)
            .pool_idle_timeout(Duration::from_secs(cc.pool_idle_timeout_secs))
            .build()?;
        // Strip a trailing slash so `format!("{}/api/...", base_url)`
        // never produces a double slash.
        let base_url = config.ollama_url.trim_end_matches('/').to_string();
        Ok(Self { client, base_url })
    }

    /// POST /api/generate (non-streaming).
    ///
    /// Sends the prompt with `stream: false` and returns the single,
    /// complete `GenerateResponse`.
    ///
    /// # Errors
    /// `Http` on transport failure, `Api` on a non-2xx status,
    /// `Deserialization` if the body is not the expected JSON shape.
    pub async fn generate(
        &self,
        model: &str,
        prompt: &str,
        options: Option<GenerateOptions>,
    ) -> Result<GenerateResponse, OllamaError> {
        let request = GenerateRequest {
            model: model.to_string(),
            prompt: prompt.to_string(),
            stream: false,
            options,
        };
        let resp = self
            .client
            .post(format!("{}/api/generate", self.base_url))
            .json(&request)
            .send()
            .await?;
        let resp = Self::check_status(resp).await?;
        resp.json::<GenerateResponse>()
            .await
            .map_err(|e| OllamaError::Deserialization(e.to_string()))
    }

    /// POST /api/generate (streaming).
    ///
    /// Returns a `Stream` of `GenerateStreamChunk` items parsed from NDJSON.
    ///
    /// Note: the stream yields whatever the server sends; it does not
    /// itself verify that a `done: true` chunk ever arrives.
    pub async fn generate_stream(
        &self,
        model: &str,
        prompt: &str,
        options: Option<GenerateOptions>,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<GenerateStreamChunk, OllamaError>> + Send>>,
        OllamaError,
    > {
        let request = GenerateRequest {
            model: model.to_string(),
            prompt: prompt.to_string(),
            stream: true,
            options,
        };
        let resp = self
            .client
            .post(format!("{}/api/generate", self.base_url))
            .json(&request)
            .send()
            .await?;
        // Non-2xx is surfaced here as `Api` before any chunk is yielded.
        let resp = Self::check_status(resp).await?;
        Ok(Self::ndjson_stream::<GenerateStreamChunk>(resp))
    }

    /// POST /api/chat (non-streaming).
    ///
    /// `messages` is the conversation passed through verbatim.
    pub async fn chat(
        &self,
        model: &str,
        messages: Vec<ChatMessage>,
        options: Option<GenerateOptions>,
    ) -> Result<ChatResponse, OllamaError> {
        let request = ChatRequest {
            model: model.to_string(),
            messages,
            stream: false,
            options,
        };
        let resp = self
            .client
            .post(format!("{}/api/chat", self.base_url))
            .json(&request)
            .send()
            .await?;
        let resp = Self::check_status(resp).await?;
        resp.json::<ChatResponse>()
            .await
            .map_err(|e| OllamaError::Deserialization(e.to_string()))
    }

    /// POST /api/embed.
    ///
    /// `input` may contain multiple texts; the response carries one
    /// embedding vector per input.
    pub async fn embed(
        &self,
        model: &str,
        input: Vec<String>,
    ) -> Result<EmbedResponse, OllamaError> {
        let request = EmbedRequest {
            model: model.to_string(),
            input,
        };
        let resp = self
            .client
            .post(format!("{}/api/embed", self.base_url))
            .json(&request)
            .send()
            .await?;
        let resp = Self::check_status(resp).await?;
        resp.json::<EmbedResponse>()
            .await
            .map_err(|e| OllamaError::Deserialization(e.to_string()))
    }

    /// GET /api/tags — list all models available on the Ollama instance.
    pub async fn list_models(&self) -> Result<ListModelsResponse, OllamaError> {
        let resp = self
            .client
            .get(format!("{}/api/tags", self.base_url))
            .send()
            .await?;
        let resp = Self::check_status(resp).await?;
        resp.json::<ListModelsResponse>()
            .await
            .map_err(|e| OllamaError::Deserialization(e.to_string()))
    }

    /// Check if Ollama is reachable.
    ///
    /// Implemented as a successful `list_models` round-trip, so it also
    /// exercises JSON deserialization, not just TCP connectivity.
    pub async fn is_healthy(&self) -> bool {
        self.list_models().await.is_ok()
    }

    /// Parse an NDJSON response body into a typed stream.
    ///
    /// Ollama streams newline-delimited JSON. Each line is a complete
    /// JSON object. This buffers bytes until newlines are found, then
    /// deserializes each complete line.
    fn ndjson_stream<T: serde::de::DeserializeOwned + Send + 'static>(
        resp: reqwest::Response,
    ) -> Pin<Box<dyn Stream<Item = Result<T, OllamaError>> + Send>> {
        let byte_stream = resp.bytes_stream();
        type ByteStream = Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>;
        // `unfold` state machine: the state is (remaining byte stream,
        // partial-line buffer); each closure invocation yields exactly one
        // parsed item, one error, or `None` to end the output stream.
        Box::pin(futures::stream::unfold(
            (Box::pin(byte_stream) as ByteStream, Vec::<u8>::new()),
            |(mut stream, mut buf): (ByteStream, Vec<u8>)| async move {
                loop {
                    // Check if buffer contains a complete line.
                    // (A single network chunk may carry several lines; any
                    // surplus stays in `buf` for the next invocation.)
                    if let Some(pos) = buf.iter().position(|&b| b == b'\n') {
                        let line = buf.drain(..=pos).collect::<Vec<_>>();
                        let line = &line[..line.len() - 1]; // strip newline
                        if line.is_empty() {
                            continue; // skip empty lines
                        }
                        let result = serde_json::from_slice::<T>(line)
                            .map_err(|e| OllamaError::Deserialization(e.to_string()));
                        return Some((result, (stream, buf)));
                    }
                    // Need more data from the stream.
                    match stream.next().await {
                        Some(Ok(chunk)) => {
                            buf.extend_from_slice(&chunk);
                        }
                        Some(Err(e)) => {
                            // Transport error mid-stream: surface it as an
                            // item; the buffer is kept but effectively dead.
                            return Some((Err(OllamaError::Http(e)), (stream, buf)));
                        }
                        None => {
                            // Stream ended. Parse any remaining data in buffer.
                            // (Handles a final line with no trailing newline.)
                            if buf.is_empty() {
                                return None;
                            }
                            // Trim trailing whitespace
                            let trimmed = buf.trim_ascii();
                            if trimmed.is_empty() {
                                return None;
                            }
                            let result = serde_json::from_slice::<T>(trimmed)
                                .map_err(|e| OllamaError::Deserialization(e.to_string()));
                            // Clear so the next invocation hits the
                            // empty-buffer path above and terminates.
                            buf.clear();
                            return Some((result, (stream, buf)));
                        }
                    }
                }
            },
        ))
    }

    /// Check response status and extract error message for non-2xx.
    ///
    /// Consumes the response on the error path (the body text becomes
    /// the `Api { message }`); passes it through untouched on success.
    async fn check_status(resp: reqwest::Response) -> Result<reqwest::Response, OllamaError> {
        if resp.status().is_success() {
            return Ok(resp);
        }
        let status = resp.status().as_u16();
        let message = resp
            .text()
            .await
            .unwrap_or_else(|_| "unknown error".to_string());
        Err(OllamaError::Api { status, message })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Build a client pointed at the given mock-server base URL, leaving
    /// every other `Config` field at its default.
    fn test_client(base_url: &str) -> OllamaClient {
        let config = Config {
            ollama_url: base_url.to_string(),
            ..Config::default()
        };
        OllamaClient::new(&config).unwrap()
    }

    // --- /api/generate (non-streaming) ---

    #[tokio::test]
    async fn test_generate_success() {
        let mock_server = MockServer::start().await;
        let body = r#"{"model":"llama3.2:3b","response":"Hello world!","done":true,"done_reason":"stop","eval_count":5}"#;
        Mock::given(method("POST"))
            .and(path("/api/generate"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
            .mount(&mock_server)
            .await;
        let client = test_client(&mock_server.uri());
        let resp = client.generate("llama3.2:3b", "Hi", None).await.unwrap();
        assert_eq!(resp.response, "Hello world!");
        assert!(resp.done);
        assert_eq!(resp.eval_count, Some(5));
    }

    // Options are serialized into the request body; the mock does not
    // inspect them, so this only proves the request round-trips cleanly.
    #[tokio::test]
    async fn test_generate_with_options() {
        let mock_server = MockServer::start().await;
        let body = r#"{"model":"llama3.2:3b","response":"test","done":true}"#;
        Mock::given(method("POST"))
            .and(path("/api/generate"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
            .mount(&mock_server)
            .await;
        let client = test_client(&mock_server.uri());
        let opts = GenerateOptions {
            temperature: Some(0.7),
            top_p: Some(0.9),
            num_predict: Some(100),
            stop: Some(vec!["STOP".to_string()]),
        };
        let resp = client
            .generate("llama3.2:3b", "Hi", Some(opts))
            .await
            .unwrap();
        assert_eq!(resp.response, "test");
    }

    // --- /api/generate (streaming NDJSON) ---

    #[tokio::test]
    async fn test_generate_stream_success() {
        let mock_server = MockServer::start().await;
        // Four newline-terminated JSON objects; the last carries done:true.
        let ndjson = concat!(
            r#"{"model":"llama3.2:3b","response":"Hello","done":false}"#, "\n",
            r#"{"model":"llama3.2:3b","response":" world","done":false}"#, "\n",
            r#"{"model":"llama3.2:3b","response":"!","done":false}"#, "\n",
            r#"{"model":"llama3.2:3b","response":"","done":true,"done_reason":"stop","eval_count":3}"#, "\n",
        );
        Mock::given(method("POST"))
            .and(path("/api/generate"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(ndjson, "application/x-ndjson"),
            )
            .mount(&mock_server)
            .await;
        let client = test_client(&mock_server.uri());
        let stream = client
            .generate_stream("llama3.2:3b", "Hi", None)
            .await
            .unwrap();
        let chunks: Vec<_> = stream.collect::<Vec<_>>().await;
        assert_eq!(chunks.len(), 4);
        assert_eq!(chunks[0].as_ref().unwrap().response, "Hello");
        assert_eq!(chunks[1].as_ref().unwrap().response, " world");
        assert_eq!(chunks[2].as_ref().unwrap().response, "!");
        assert!(chunks[3].as_ref().unwrap().done);
    }

    // A single object with no trailing newline exercises the end-of-stream
    // remainder path in `ndjson_stream`.
    #[tokio::test]
    async fn test_generate_stream_single_done() {
        let mock_server = MockServer::start().await;
        let ndjson = r#"{"model":"llama3.2:3b","response":"","done":true,"done_reason":"stop"}"#;
        Mock::given(method("POST"))
            .and(path("/api/generate"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(ndjson, "application/x-ndjson"),
            )
            .mount(&mock_server)
            .await;
        let client = test_client(&mock_server.uri());
        let stream = client
            .generate_stream("llama3.2:3b", "Hi", None)
            .await
            .unwrap();
        let chunks: Vec<_> = stream.collect::<Vec<_>>().await;
        assert_eq!(chunks.len(), 1);
        assert!(chunks[0].as_ref().unwrap().done);
    }

    // --- /api/chat ---

    #[tokio::test]
    async fn test_chat_success() {
        let mock_server = MockServer::start().await;
        let body = r#"{"model":"llama3.2:3b","message":{"role":"assistant","content":"Hi there!"},"done":true,"done_reason":"stop"}"#;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
            .mount(&mock_server)
            .await;
        let client = test_client(&mock_server.uri());
        let msgs = vec![ChatMessage {
            role: ChatRole::User,
            content: "Hello".to_string(),
        }];
        let resp = client.chat("llama3.2:3b", msgs, None).await.unwrap();
        assert_eq!(resp.message.content, "Hi there!");
    }

    // Multi-turn history including a system prompt passes through intact.
    #[tokio::test]
    async fn test_chat_with_history() {
        let mock_server = MockServer::start().await;
        let body = r#"{"model":"llama3.2:3b","message":{"role":"assistant","content":"Fine, thanks!"},"done":true}"#;
        Mock::given(method("POST"))
            .and(path("/api/chat"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
            .mount(&mock_server)
            .await;
        let client = test_client(&mock_server.uri());
        let msgs = vec![
            ChatMessage {
                role: ChatRole::System,
                content: "Be helpful".to_string(),
            },
            ChatMessage {
                role: ChatRole::User,
                content: "Hi".to_string(),
            },
            ChatMessage {
                role: ChatRole::Assistant,
                content: "Hello!".to_string(),
            },
            ChatMessage {
                role: ChatRole::User,
                content: "How are you?".to_string(),
            },
        ];
        let resp = client.chat("llama3.2:3b", msgs, None).await.unwrap();
        assert_eq!(resp.message.content, "Fine, thanks!");
    }

    // --- /api/embed ---

    #[tokio::test]
    async fn test_embed_success() {
        let mock_server = MockServer::start().await;
        let body = r#"{"model":"nomic-embed-text","embeddings":[[0.1,0.2,0.3]]}"#;
        Mock::given(method("POST"))
            .and(path("/api/embed"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
            .mount(&mock_server)
            .await;
        let client = test_client(&mock_server.uri());
        let resp = client
            .embed("nomic-embed-text", vec!["hello".to_string()])
            .await
            .unwrap();
        assert_eq!(resp.embeddings.len(), 1);
        assert_eq!(resp.embeddings[0].len(), 3);
    }

    #[tokio::test]
    async fn test_embed_multiple_inputs() {
        let mock_server = MockServer::start().await;
        let body = r#"{"model":"nomic-embed-text","embeddings":[[0.1,0.2],[0.3,0.4]]}"#;
        Mock::given(method("POST"))
            .and(path("/api/embed"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
            .mount(&mock_server)
            .await;
        let client = test_client(&mock_server.uri());
        let resp = client
            .embed(
                "nomic-embed-text",
                vec!["hello".to_string(), "world".to_string()],
            )
            .await
            .unwrap();
        assert_eq!(resp.embeddings.len(), 2);
    }

    // --- /api/tags + health ---

    #[tokio::test]
    async fn test_list_models_success() {
        let mock_server = MockServer::start().await;
        let body = r#"{"models":[{"name":"llama3.2:3b","model":"llama3.2:3b","size":1234567},{"name":"nomic-embed-text","model":"nomic-embed-text","size":987654}]}"#;
        Mock::given(method("GET"))
            .and(path("/api/tags"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
            .mount(&mock_server)
            .await;
        let client = test_client(&mock_server.uri());
        let resp = client.list_models().await.unwrap();
        assert_eq!(resp.models.len(), 2);
    }

    #[tokio::test]
    async fn test_list_models_empty() {
        let mock_server = MockServer::start().await;
        let body = r#"{"models":[]}"#;
        Mock::given(method("GET"))
            .and(path("/api/tags"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
            .mount(&mock_server)
            .await;
        let client = test_client(&mock_server.uri());
        let resp = client.list_models().await.unwrap();
        assert!(resp.models.is_empty());
    }

    #[tokio::test]
    async fn test_is_healthy_success() {
        let mock_server = MockServer::start().await;
        let body = r#"{"models":[]}"#;
        Mock::given(method("GET"))
            .and(path("/api/tags"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
            .mount(&mock_server)
            .await;
        let client = test_client(&mock_server.uri());
        assert!(client.is_healthy().await);
    }

    // A 5xx from /api/tags must read as unhealthy, not panic.
    #[tokio::test]
    async fn test_is_healthy_failure() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/api/tags"))
            .respond_with(ResponseTemplate::new(500).set_body_string("internal error"))
            .mount(&mock_server)
            .await;
        let client = test_client(&mock_server.uri());
        assert!(!client.is_healthy().await);
    }

    // --- error mapping ---

    // Non-2xx maps to OllamaError::Api with status and body preserved.
    #[tokio::test]
    async fn test_api_error_404() {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/generate"))
            .respond_with(ResponseTemplate::new(404).set_body_string("model 'nonexistent' not found"))
            .mount(&mock_server)
            .await;
        let client = test_client(&mock_server.uri());
        let err = client
            .generate("nonexistent", "Hi", None)
            .await
            .unwrap_err();
        match err {
            OllamaError::Api { status, message } => {
                assert_eq!(status, 404);
                assert!(message.contains("not found"));
            }
            other => panic!("expected Api error, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_api_error_500() {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/generate"))
            .respond_with(ResponseTemplate::new(500).set_body_string("internal server error"))
            .mount(&mock_server)
            .await;
        let client = test_client(&mock_server.uri());
        let err = client
            .generate("llama3.2:3b", "Hi", None)
            .await
            .unwrap_err();
        match err {
            OllamaError::Api { status, .. } => assert_eq!(status, 500),
            other => panic!("expected Api error, got: {other:?}"),
        }
    }

    // `OllamaClient::new` normalizes a trailing slash on the base URL,
    // so requests still hit /api/tags rather than //api/tags.
    #[tokio::test]
    async fn test_base_url_trailing_slash() {
        let mock_server = MockServer::start().await;
        let body = r#"{"models":[]}"#;
        Mock::given(method("GET"))
            .and(path("/api/tags"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
            .mount(&mock_server)
            .await;
        // URL with trailing slash
        let config = Config {
            ollama_url: format!("{}/", mock_server.uri()),
            ..Config::default()
        };
        let client = OllamaClient::new(&config).unwrap();
        let resp = client.list_models().await.unwrap();
        assert!(resp.models.is_empty());
    }
}

View File

@@ -0,0 +1,51 @@
use thiserror::Error;
/// Errors that can arise while talking to the Ollama REST API.
#[derive(Debug, Error)]
pub enum OllamaError {
    /// Ollama answered with a non-2xx status; `message` carries the
    /// response body text.
    #[error("Ollama API error (status {status}): {message}")]
    Api { status: u16, message: String },
    /// A response body could not be decoded into the expected JSON shape.
    #[error("deserialization error: {0}")]
    Deserialization(String),
    /// Transport-level failure (connection refused, timeout, DNS, and
    /// anything else `reqwest` reports).
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    /// A streaming response ended before a `done:true` chunk arrived.
    #[error("stream ended unexpectedly")]
    StreamIncomplete,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The Api variant must surface both the status code and the body.
    #[test]
    fn test_error_display_api() {
        let rendered = OllamaError::Api {
            status: 404,
            message: "model not found".to_string(),
        }
        .to_string();
        assert!(rendered.contains("404"));
        assert!(rendered.contains("model not found"));
    }

    /// The wrapped serde message shows up in the Display output.
    #[test]
    fn test_error_display_deserialization() {
        let rendered = OllamaError::Deserialization("unexpected EOF".to_string()).to_string();
        assert!(rendered.contains("unexpected EOF"));
    }

    /// The unit variant renders its fixed message.
    #[test]
    fn test_error_display_stream_incomplete() {
        let rendered = OllamaError::StreamIncomplete.to_string();
        assert!(rendered.contains("stream ended unexpectedly"));
    }
}

View File

@@ -0,0 +1,6 @@
pub mod client;
pub mod error;
pub mod types;
pub use client::OllamaClient;
pub use error::OllamaError;

View File

@@ -0,0 +1,301 @@
use serde::{Deserialize, Serialize};
// --- /api/generate ---
/// Request body for `POST /api/generate`.
#[derive(Debug, Serialize)]
pub struct GenerateRequest {
    pub model: String,
    pub prompt: String,
    /// `false` → single JSON object response; `true` → NDJSON chunk stream.
    pub stream: bool,
    /// Omitted from the serialized body entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<GenerateOptions>,
}
/// Sampling options for `/api/generate` and `/api/chat` requests.
///
/// Every field is optional and `None` fields are omitted from the JSON
/// body. Deriving `Default` lets callers set only the fields they care
/// about via struct-update syntax
/// (`GenerateOptions { temperature: Some(0.7), ..Default::default() }`),
/// and `Clone` lets a single options value be reused across requests.
#[derive(Debug, Clone, Default, Serialize)]
pub struct GenerateOptions {
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus-sampling cutoff.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Maximum number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_predict: Option<i32>,
    /// Stop sequences that terminate generation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}
/// Full response from /api/generate with stream:false.
#[derive(Debug, Deserialize)]
pub struct GenerateResponse {
    pub model: String,
    /// The generated text.
    pub response: String,
    pub done: bool,
    /// Why generation stopped (e.g. "stop"); not always present.
    #[serde(default)]
    pub done_reason: Option<String>,
    /// Total request duration — presumably nanoseconds per Ollama
    /// convention; confirm against the Ollama API docs before relying on it.
    #[serde(default)]
    pub total_duration: Option<u64>,
    /// Number of generated tokens.
    #[serde(default)]
    pub eval_count: Option<u32>,
    /// Number of prompt tokens evaluated.
    #[serde(default)]
    pub prompt_eval_count: Option<u32>,
}
/// Single chunk from /api/generate with stream:true (NDJSON).
///
/// Intermediate chunks carry a `response` text fragment with
/// `done: false`; the final chunk has `done: true` and may also carry
/// `done_reason` and the token counts.
#[derive(Debug, Deserialize)]
pub struct GenerateStreamChunk {
    pub model: String,
    /// Text fragment for this chunk (may be empty on the final chunk).
    pub response: String,
    pub done: bool,
    #[serde(default)]
    pub done_reason: Option<String>,
    #[serde(default)]
    pub eval_count: Option<u32>,
    #[serde(default)]
    pub prompt_eval_count: Option<u32>,
}
// --- /api/chat ---
/// Role tag for a chat message, serialized lowercase ("system", "user",
/// "assistant") to match the Ollama chat API.
///
/// Fieldless, so `Copy` and `Eq` are free; `PartialEq`/`Eq` let callers
/// and tests compare roles directly instead of matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    System,
    User,
    Assistant,
}
/// One turn of a chat conversation; used in both requests and responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    pub role: ChatRole,
    pub content: String,
}
/// Request body for `POST /api/chat`.
#[derive(Debug, Serialize)]
pub struct ChatRequest {
    pub model: String,
    /// Conversation messages sent to the model.
    pub messages: Vec<ChatMessage>,
    /// `false` → single JSON object response; `true` → NDJSON chunk stream.
    pub stream: bool,
    /// Omitted from the serialized body entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<GenerateOptions>,
}
/// Response from /api/chat with stream:false.
#[derive(Debug, Deserialize)]
pub struct ChatResponse {
    pub model: String,
    /// The assistant's reply.
    pub message: ChatMessage,
    pub done: bool,
    /// Why generation stopped (e.g. "stop"); not always present.
    #[serde(default)]
    pub done_reason: Option<String>,
    /// Total request duration — presumably nanoseconds; confirm against
    /// the Ollama API docs.
    #[serde(default)]
    pub total_duration: Option<u64>,
    /// Number of generated tokens.
    #[serde(default)]
    pub eval_count: Option<u32>,
    /// Number of prompt tokens evaluated.
    #[serde(default)]
    pub prompt_eval_count: Option<u32>,
}
// --- /api/embed ---
/// Request body for `POST /api/embed` (batch form: `input` is an array).
#[derive(Debug, Serialize)]
pub struct EmbedRequest {
    pub model: String,
    /// Texts to embed; one embedding vector is returned per entry.
    pub input: Vec<String>,
}
/// Response from /api/embed.
#[derive(Debug, Deserialize)]
pub struct EmbedResponse {
    pub model: String,
    /// One embedding vector per input text, in request order.
    pub embeddings: Vec<Vec<f32>>,
}
// --- /api/tags (list models) ---
/// Response from GET /api/tags: all models installed on the instance.
#[derive(Debug, Deserialize)]
pub struct ListModelsResponse {
    pub models: Vec<ModelInfo>,
}
/// Metadata for a single installed model, as returned by /api/tags.
#[derive(Debug, Deserialize)]
pub struct ModelInfo {
    pub name: String,
    pub model: String,
    /// Model size — presumably bytes; defaults to 0 when absent.
    #[serde(default)]
    pub size: u64,
    /// Content digest; not always present.
    #[serde(default)]
    pub digest: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- serialization (request types) ---

    // With options:None, the `options` key must be absent, not null.
    #[test]
    fn test_generate_request_serialization() {
        let req = GenerateRequest {
            model: "llama3.2:3b".to_string(),
            prompt: "Hello".to_string(),
            stream: false,
            options: None,
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["model"], "llama3.2:3b");
        assert_eq!(json["prompt"], "Hello");
        assert_eq!(json["stream"], false);
        assert!(json.get("options").is_none());
    }

    // Set options serialize; unset (top_p) are skipped inside `options` too.
    #[test]
    fn test_generate_request_serialization_with_options() {
        let req = GenerateRequest {
            model: "llama3.2:3b".to_string(),
            prompt: "Hello".to_string(),
            stream: false,
            options: Some(GenerateOptions {
                temperature: Some(0.7),
                top_p: None,
                num_predict: Some(100),
                stop: Some(vec!["<|end|>".to_string()]),
            }),
        };
        let json = serde_json::to_value(&req).unwrap();
        let opts = &json["options"];
        // f32 round-trips through f64 in serde_json; compare with tolerance.
        assert!((opts["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001);
        assert!(opts.get("top_p").is_none());
        assert_eq!(opts["num_predict"], 100);
        assert_eq!(opts["stop"][0], "<|end|>");
    }

    // --- deserialization (response types) ---

    #[test]
    fn test_generate_response_deserialization() {
        let json = r#"{
            "model": "llama3.2:3b",
            "response": "Hello world!",
            "done": true,
            "done_reason": "stop",
            "total_duration": 1234567890,
            "eval_count": 42,
            "prompt_eval_count": 10
        }"#;
        let resp: GenerateResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.model, "llama3.2:3b");
        assert_eq!(resp.response, "Hello world!");
        assert!(resp.done);
        assert_eq!(resp.done_reason.as_deref(), Some("stop"));
        assert_eq!(resp.eval_count, Some(42));
    }

    // Optional fields default to None rather than failing deserialization.
    #[test]
    fn test_generate_response_missing_optional_fields() {
        let json = r#"{
            "model": "llama3.2:3b",
            "response": "Hello",
            "done": true
        }"#;
        let resp: GenerateResponse = serde_json::from_str(json).unwrap();
        assert!(resp.done_reason.is_none());
        assert!(resp.total_duration.is_none());
        assert!(resp.eval_count.is_none());
    }

    #[test]
    fn test_generate_stream_chunk_deserialization() {
        let json = r#"{"model":"llama3.2:3b","response":"Hello","done":false}"#;
        let chunk: GenerateStreamChunk = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.response, "Hello");
        assert!(!chunk.done);
    }

    #[test]
    fn test_generate_stream_chunk_final() {
        let json = r#"{
            "model": "llama3.2:3b",
            "response": "",
            "done": true,
            "done_reason": "stop",
            "eval_count": 15
        }"#;
        let chunk: GenerateStreamChunk = serde_json::from_str(json).unwrap();
        assert!(chunk.done);
        assert_eq!(chunk.done_reason.as_deref(), Some("stop"));
        assert_eq!(chunk.eval_count, Some(15));
    }

    // --- chat types ---

    #[test]
    fn test_chat_request_serialization() {
        let req = ChatRequest {
            model: "llama3.2:3b".to_string(),
            messages: vec![
                ChatMessage {
                    role: ChatRole::System,
                    content: "You are helpful.".to_string(),
                },
                ChatMessage {
                    role: ChatRole::User,
                    content: "Hello".to_string(),
                },
            ],
            stream: false,
            options: None,
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["messages"].as_array().unwrap().len(), 2);
        // rename_all = "lowercase" on ChatRole.
        assert_eq!(json["messages"][0]["role"], "system");
        assert_eq!(json["messages"][1]["role"], "user");
    }

    #[test]
    fn test_chat_role_serialization() {
        let json = serde_json::to_value(ChatRole::Assistant).unwrap();
        assert_eq!(json, "assistant");
    }

    #[test]
    fn test_chat_response_deserialization() {
        let json = r#"{
            "model": "llama3.2:3b",
            "message": {"role": "assistant", "content": "Hi there!"},
            "done": true,
            "done_reason": "stop",
            "eval_count": 5
        }"#;
        let resp: ChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.message.content, "Hi there!");
        assert!(resp.done);
    }

    // --- embed types ---

    #[test]
    fn test_embed_request_serialization() {
        let req = EmbedRequest {
            model: "nomic-embed-text".to_string(),
            input: vec!["hello".to_string(), "world".to_string()],
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["model"], "nomic-embed-text");
        assert_eq!(json["input"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_embed_response_deserialization() {
        let json = r#"{
            "model": "nomic-embed-text",
            "embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
        }"#;
        let resp: EmbedResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.embeddings.len(), 2);
        assert_eq!(resp.embeddings[0].len(), 3);
    }

    // --- /api/tags types ---

    // `digest` is optional; `size` falls back to its default when missing.
    #[test]
    fn test_list_models_response_deserialization() {
        let json = r#"{
            "models": [
                {"name": "llama3.2:3b", "model": "llama3.2:3b", "size": 1234567, "digest": "abc123"},
                {"name": "nomic-embed-text", "model": "nomic-embed-text", "size": 987654}
            ]
        }"#;
        let resp: ListModelsResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.models.len(), 2);
        assert_eq!(resp.models[0].name, "llama3.2:3b");
        assert_eq!(resp.models[0].digest.as_deref(), Some("abc123"));
        assert!(resp.models[1].digest.is_none());
    }
}

View File

@@ -6,15 +6,19 @@ use llm_multiverse_proto::llm_multiverse::v1::{
use tonic::{Request, Response, Status};
use crate::config::Config;
use crate::ollama::OllamaClient;
/// Implementation of the ModelGatewayService gRPC trait.
pub struct ModelGatewayServiceImpl {
    /// Service configuration (listen address, routing, client settings).
    config: Config,
    /// HTTP client for the local Ollama instance. Not yet called by any
    /// RPC handler, hence the dead_code allowance — constructed eagerly
    /// so misconfiguration fails at startup rather than on first request.
    #[allow(dead_code)]
    ollama: OllamaClient,
}
impl ModelGatewayServiceImpl {
pub fn new(config: Config) -> Self {
Self { config }
pub fn new(config: Config) -> Result<Self, anyhow::Error> {
let ollama = OllamaClient::new(&config)?;
Ok(Self { config, ollama })
}
}
@@ -90,7 +94,7 @@ mod tests {
#[tokio::test]
async fn test_is_model_ready_all_models() {
let svc = ModelGatewayServiceImpl::new(test_config());
let svc = ModelGatewayServiceImpl::new(test_config()).unwrap();
let req = Request::new(IsModelReadyRequest { model_name: None });
let resp = svc.is_model_ready(req).await.unwrap().into_inner();
@@ -105,7 +109,7 @@ mod tests {
#[tokio::test]
async fn test_is_model_ready_specific_model_found() {
let svc = ModelGatewayServiceImpl::new(test_config());
let svc = ModelGatewayServiceImpl::new(test_config()).unwrap();
let req = Request::new(IsModelReadyRequest {
model_name: Some("llama3.2:3b".to_string()),
});
@@ -118,7 +122,7 @@ mod tests {
#[tokio::test]
async fn test_is_model_ready_specific_model_not_found() {
let svc = ModelGatewayServiceImpl::new(test_config());
let svc = ModelGatewayServiceImpl::new(test_config()).unwrap();
let req = Request::new(IsModelReadyRequest {
model_name: Some("nonexistent:latest".to_string()),
});
@@ -137,7 +141,7 @@ mod tests {
.aliases
.insert("code".into(), "codellama:7b".into());
let svc = ModelGatewayServiceImpl::new(config);
let svc = ModelGatewayServiceImpl::new(config).unwrap();
let req = Request::new(IsModelReadyRequest {
model_name: Some("codellama:7b".to_string()),
});
@@ -148,7 +152,7 @@ mod tests {
#[tokio::test]
async fn test_stream_inference_unimplemented() {
let svc = ModelGatewayServiceImpl::new(test_config());
let svc = ModelGatewayServiceImpl::new(test_config()).unwrap();
let req = Request::new(StreamInferenceRequest { params: None });
let result = svc.stream_inference(req).await;
@@ -158,7 +162,7 @@ mod tests {
#[tokio::test]
async fn test_inference_unimplemented() {
let svc = ModelGatewayServiceImpl::new(test_config());
let svc = ModelGatewayServiceImpl::new(test_config()).unwrap();
let req = Request::new(InferenceRequest { params: None });
let result = svc.inference(req).await;
@@ -168,7 +172,7 @@ mod tests {
#[tokio::test]
async fn test_generate_embedding_unimplemented() {
let svc = ModelGatewayServiceImpl::new(test_config());
let svc = ModelGatewayServiceImpl::new(test_config()).unwrap();
let req = Request::new(GenerateEmbeddingRequest {
context: None,
text: "test".into(),