diff --git a/Cargo.lock b/Cargo.lock index a11f04b..ed384a7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -230,6 +230,16 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "async-trait" version = "0.1.89" @@ -490,6 +500,26 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -530,6 +560,24 @@ dependencies = [ "typenum", ] +[[package]] +name = "deadpool" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b" +dependencies = [ + "deadpool-runtime", + "lazy_static", + "num_cpus", + "tokio", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" + [[package]] name = "derive_arbitrary" version = "1.4.2" @@ -585,6 +633,15 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -665,6 +722,21 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -888,6 +960,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "http" version = "1.4.0" @@ -986,6 +1064,22 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.20" @@ -1004,9 +1098,11 @@ dependencies = [ "percent-encoding", "pin-project-lite", "socket2 0.6.3", + "system-configuration", "tokio", "tower-service", "tracing", + "windows-registry", ] [[package]] @@ 
-1455,10 +1551,14 @@ name = "model-gateway" version = "0.1.0" dependencies = [ "anyhow", + "bytes", + "futures", "llm-multiverse-proto", "prost", "prost-types", + "reqwest", "serde", + "serde_json", "tempfile", "thiserror", "tokio", @@ -1467,6 +1567,7 @@ dependencies = [ "tonic", "tracing", "tracing-subscriber", + "wiremock", ] [[package]] @@ -1475,6 +1576,23 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" +[[package]] +name = "native-tls" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -1558,12 +1676,66 @@ dependencies = [ "libm", ] +[[package]] +name = "num_cpus" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "once_cell" version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "openssl" +version = "0.10.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + 
"proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + +[[package]] +name = "openssl-sys" +version = "0.9.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "parking_lot" version = "0.12.5" @@ -1970,17 +2142,22 @@ checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ "base64", "bytes", + "encoding_rs", "futures-channel", "futures-core", "futures-util", + "h2", "http", "http-body", "http-body-util", "hyper", "hyper-rustls", + "hyper-tls", "hyper-util", "js-sys", "log", + "mime", + "native-tls", "percent-encoding", "pin-project-lite", "quinn", @@ -1991,13 +2168,16 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", + "tokio-native-tls", "tokio-rustls", + "tokio-util", "tower", "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", ] @@ -2127,6 +2307,15 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" +[[package]] +name = "schannel" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2161,6 +2350,29 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "security-framework" +version = "3.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" 
+dependencies = [ + "bitflags", + "core-foundation 0.10.1", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.27" @@ -2405,6 +2617,27 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "system-configuration" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b" +dependencies = [ + "bitflags", + "core-foundation 0.9.4", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tap" version = "1.0.1" @@ -2526,6 +2759,16 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.4" @@ -2980,6 +3223,19 @@ dependencies = [ "wasmparser", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "wasmparser" version = "0.244.0" @@ -3062,6 +3318,17 @@ version = "0.2.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" +dependencies = [ + "windows-link", + "windows-result", + "windows-strings", +] + [[package]] name = "windows-result" version = "0.4.1" @@ -3171,6 +3438,29 @@ dependencies = [ "memchr", ] +[[package]] +name = "wiremock" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08db1edfb05d9b3c1542e521aea074442088292f00b5f28e435c714a98f85031" +dependencies = [ + "assert-json-diff", + "base64", + "deadpool", + "futures", + "http", + "http-body-util", + "hyper", + "hyper-util", + "log", + "once_cell", + "regex", + "serde", + "serde_json", + "tokio", + "url", +] + [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/implementation-plans/_index.md b/implementation-plans/_index.md index 2974d44..e1d637c 100644 --- a/implementation-plans/_index.md +++ b/implementation-plans/_index.md @@ -42,6 +42,7 @@ | #36 | Implement GetCorrelated gRPC endpoint | Phase 4 | `COMPLETED` | Rust | [issue-036.md](issue-036.md) | | #37 | Integration tests for Memory Service | Phase 4 | `COMPLETED` | Rust | [issue-037.md](issue-037.md) | | #38 | Scaffold Model Gateway Rust project | Phase 5 | `COMPLETED` | Rust | [issue-038.md](issue-038.md) | +| #39 | Implement Ollama HTTP client | Phase 5 | `IMPLEMENTING` | Rust | [issue-039.md](issue-039.md) | ## Status Legend @@ -83,6 +84,8 @@ ### Model Gateway - [issue-012.md](issue-012.md) — model_gateway.proto (ModelGatewayService) +- [issue-038.md](issue-038.md) — Scaffold Model Gateway Rust project +- [issue-039.md](issue-039.md) — Ollama HTTP client (reqwest, streaming, embeddings) ### Search Service - [issue-013.md](issue-013.md) — search.proto (SearchService) diff 
--git a/implementation-plans/issue-039.md b/implementation-plans/issue-039.md new file mode 100644 index 0000000..4edc9ab --- /dev/null +++ b/implementation-plans/issue-039.md @@ -0,0 +1,631 @@ +# Implementation Plan — Issue #39: Implement Ollama HTTP client + +## Metadata + +| Field | Value | +|---|---| +| Issue | [#39](https://git.shahondin1624.de/llm-multiverse/llm-multiverse/issues/39) | +| Title | Implement Ollama HTTP client | +| Milestone | Phase 5: Model Gateway | +| Labels | | +| Status | `IMPLEMENTING` | +| Language | Rust | +| Related Plans | [issue-012.md](issue-012.md), [issue-038.md](issue-038.md) | +| Blocked by | #38 (completed) | + +## Acceptance Criteria + +- [ ] Async HTTP client using reqwest +- [ ] Support for /api/generate (streaming and non-streaming) +- [ ] Support for /api/chat with message history +- [ ] Support for /api/embed (embeddings) +- [ ] Connection pooling and timeout configuration +- [ ] Error handling for Ollama-specific error responses + +## Architecture Analysis + +### Service Context + +This issue belongs to the **Model Gateway** service (`services/model-gateway/`). The Ollama HTTP client is the core backend layer that the gRPC service handlers (`service.rs`) will call to fulfil `Inference`, `StreamInference`, `GenerateEmbedding`, and `IsModelReady` RPCs. + +The client wraps the Ollama REST API (default `http://localhost:11434`) and exposes typed Rust methods that the gRPC handlers can call directly. The gRPC handlers (issue #40+) will translate proto request/response types to/from the Ollama client types defined here. 
+ +**gRPC endpoints affected (consumers of this client):** +- `Inference` — will call `OllamaClient::generate()` (non-streaming) +- `StreamInference` — will call `OllamaClient::generate_stream()` (streaming) +- `GenerateEmbedding` — will call `OllamaClient::embed()` +- `IsModelReady` — will call `OllamaClient::list_models()` to check actual Ollama availability + +**Proto messages involved:** +- `InferenceParams` — carries prompt, model routing hints (TaskComplexity), temperature, top_p, max_tokens, stop_sequences +- `InferenceResponse` — text, finish_reason, tokens_used +- `StreamInferenceResponse` — token, finish_reason +- `GenerateEmbeddingRequest` — text, model +- `GenerateEmbeddingResponse` — embedding vector, dimensions + +### Existing Patterns + +- **Config:** `services/model-gateway/src/config.rs` already defines `Config` with `ollama_url: String` (default `http://localhost:11434`) and `ModelRoutingConfig` for model name resolution. +- **Service struct:** `services/model-gateway/src/service.rs` defines `ModelGatewayServiceImpl` holding `Config`. The `OllamaClient` will be added here as a field. +- **Error types:** Other services use `thiserror` for module-level error enums (e.g., `DbError`, `EmbeddingError`, `ProvenanceError`). The model-gateway `Cargo.toml` already includes `thiserror = "2"`. +- **Async runtime:** `tokio` with `features = ["full"]` is already a dependency. `tokio-stream = "0.1"` is also present. +- **Serde:** `serde = { version = "1", features = ["derive"] }` is already a dependency for config deserialization. + +### Dependencies + +- **reqwest** (new) — HTTP client with connection pooling, async support, JSON serialization, and streaming response bodies. Features needed: `json`, `stream`. +- **futures** (new) — For `Stream` trait and stream combinators (`futures::Stream`, `futures::StreamExt`). Needed to expose streaming generate responses as a `Stream` type. 
+- **serde_json** (new) — For parsing newline-delimited JSON (NDJSON) from Ollama streaming responses. While `reqwest` can deserialize full JSON responses, streaming requires manual line-by-line parsing. +- **No proto changes required** — the Ollama client is an internal HTTP layer; the proto definitions are already complete from issue #12. + +## Implementation Steps + +### 1. Types & Configuration + +**Add Ollama-specific configuration to `services/model-gateway/src/config.rs`:** + +```rust +/// Configuration for the Ollama HTTP client. +#[derive(Debug, Clone, Deserialize)] +pub struct OllamaClientConfig { + /// Request timeout in seconds (default: 300 — generous for large model inference). + #[serde(default = "default_request_timeout_secs")] + pub request_timeout_secs: u64, + + /// Connection timeout in seconds (default: 10). + #[serde(default = "default_connect_timeout_secs")] + pub connect_timeout_secs: u64, + + /// Maximum idle connections in the pool (default: 10). + #[serde(default = "default_pool_max_idle")] + pub pool_max_idle: usize, + + /// Idle connection timeout in seconds (default: 60). + #[serde(default = "default_pool_idle_timeout_secs")] + pub pool_idle_timeout_secs: u64, +} +``` + +Add `#[serde(default)] pub client: OllamaClientConfig` field to the existing `Config` struct. + +**Define Ollama API request/response types in `services/model-gateway/src/ollama/types.rs`:** + +These are serde structs matching the Ollama REST API JSON schema. 
+ +```rust +use serde::{Deserialize, Serialize}; + +// --- /api/generate --- + +#[derive(Debug, Serialize)] +pub struct GenerateRequest { + pub model: String, + pub prompt: String, + pub stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option, +} + +#[derive(Debug, Serialize)] +pub struct GenerateOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub num_predict: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, +} + +/// Full response from /api/generate with stream:false. +#[derive(Debug, Deserialize)] +pub struct GenerateResponse { + pub model: String, + pub response: String, + pub done: bool, + #[serde(default)] + pub done_reason: Option, + #[serde(default)] + pub total_duration: Option, + #[serde(default)] + pub eval_count: Option, + #[serde(default)] + pub prompt_eval_count: Option, +} + +/// Single chunk from /api/generate with stream:true (NDJSON). 
+#[derive(Debug, Deserialize)] +pub struct GenerateStreamChunk { + pub model: String, + pub response: String, + pub done: bool, + #[serde(default)] + pub done_reason: Option, + #[serde(default)] + pub eval_count: Option, + #[serde(default)] + pub prompt_eval_count: Option, +} + +// --- /api/chat --- + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ChatRole { + System, + User, + Assistant, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatMessage { + pub role: ChatRole, + pub content: String, +} + +#[derive(Debug, Serialize)] +pub struct ChatRequest { + pub model: String, + pub messages: Vec, + pub stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option, +} + +#[derive(Debug, Deserialize)] +pub struct ChatResponse { + pub model: String, + pub message: ChatMessage, + pub done: bool, + #[serde(default)] + pub done_reason: Option, + #[serde(default)] + pub total_duration: Option, + #[serde(default)] + pub eval_count: Option, + #[serde(default)] + pub prompt_eval_count: Option, +} + +// --- /api/embed --- + +#[derive(Debug, Serialize)] +pub struct EmbedRequest { + pub model: String, + pub input: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct EmbedResponse { + pub model: String, + pub embeddings: Vec>, +} + +// --- /api/tags (list models) --- + +#[derive(Debug, Deserialize)] +pub struct ListModelsResponse { + pub models: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct ModelInfo { + pub name: String, + pub model: String, + #[serde(default)] + pub size: u64, + #[serde(default)] + pub digest: Option, +} + +// --- /api/show (model details) --- + +#[derive(Debug, Serialize)] +pub struct ShowModelRequest { + pub model: String, +} + +#[derive(Debug, Deserialize)] +pub struct ShowModelResponse { + pub modelfile: Option, + pub parameters: Option, + pub template: Option, +} +``` + +### 2. 
Core Logic + +**Create `services/model-gateway/src/ollama/error.rs` — Error types:** + +```rust +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum OllamaError { + /// HTTP-level error (connection refused, timeout, DNS, TLS, etc.). + #[error("HTTP error: {0}")] + Http(#[from] reqwest::Error), + + /// Ollama returned a non-2xx status code. + #[error("Ollama API error (status {status}): {message}")] + Api { + status: u16, + message: String, + }, + + /// Failed to deserialize Ollama JSON response. + #[error("deserialization error: {0}")] + Deserialization(String), + + /// Stream terminated unexpectedly without a done:true chunk. + #[error("stream ended unexpectedly")] + StreamIncomplete, +} +``` + +**Create `services/model-gateway/src/ollama/client.rs` — OllamaClient:** + +```rust +use std::time::Duration; +use futures::Stream; +use reqwest::Client; + +use crate::config::{Config, OllamaClientConfig}; +use super::error::OllamaError; +use super::types::*; + +/// Async HTTP client for the Ollama REST API. +/// +/// Wraps `reqwest::Client` with connection pooling, timeouts, and +/// typed methods for each Ollama endpoint. +pub struct OllamaClient { + client: Client, + base_url: String, +} + +impl OllamaClient { + /// Create a new client from the service configuration. + /// + /// Configures connection pooling, timeouts, and the base URL + /// from `Config.ollama_url` and `Config.client`. + pub fn new(config: &Config) -> Result<Self, OllamaError> { + let client_config = &config.client; + let client = Client::builder() + .timeout(Duration::from_secs(client_config.request_timeout_secs)) + .connect_timeout(Duration::from_secs(client_config.connect_timeout_secs)) + .pool_max_idle_per_host(client_config.pool_max_idle) + .pool_idle_timeout(Duration::from_secs(client_config.pool_idle_timeout_secs)) + .build()?; + + let base_url = config.ollama_url.trim_end_matches('/').to_string(); + + Ok(Self { client, base_url }) + } + + /// POST /api/generate (non-streaming).
+ /// + /// Sends a prompt to the specified model and returns the complete response. + pub async fn generate( + &self, + model: &str, + prompt: &str, + options: Option<GenerateOptions>, + ) -> Result<GenerateResponse, OllamaError> { + let request = GenerateRequest { + model: model.to_string(), + prompt: prompt.to_string(), + stream: false, + options, + }; + + let resp = self.client + .post(format!("{}/api/generate", self.base_url)) + .json(&request) + .send() + .await?; + + self.handle_error_response(resp) + .await? + .json::<GenerateResponse>() + .await + .map_err(|e| OllamaError::Deserialization(e.to_string())) + } + + /// POST /api/generate (streaming). + /// + /// Returns a `Stream` of `GenerateStreamChunk` items. Each chunk + /// contains a partial token. The final chunk has `done: true`. + /// + /// Ollama streams NDJSON (one JSON object per line). This method + /// reads the response body as a byte stream, splits on newlines, + /// and deserializes each line. + pub async fn generate_stream( + &self, + model: &str, + prompt: &str, + options: Option<GenerateOptions>, + ) -> Result< + impl Stream<Item = Result<GenerateStreamChunk, OllamaError>>, + OllamaError, + > { + let request = GenerateRequest { + model: model.to_string(), + prompt: prompt.to_string(), + stream: true, + options, + }; + + let resp = self.client + .post(format!("{}/api/generate", self.base_url)) + .json(&request) + .send() + .await?; + + let resp = self.handle_error_response(resp).await?; + Ok(Self::ndjson_stream::<GenerateStreamChunk>(resp)) + } + + /// POST /api/chat (non-streaming). + /// + /// Sends a chat conversation (message history) to the model. + pub async fn chat( + &self, + model: &str, + messages: Vec<ChatMessage>, + options: Option<GenerateOptions>, + ) -> Result<ChatResponse, OllamaError> { + let request = ChatRequest { + model: model.to_string(), + messages, + stream: false, + options, + }; + + let resp = self.client + .post(format!("{}/api/chat", self.base_url)) + .json(&request) + .send() + .await?; + + self.handle_error_response(resp) + .await? + .json::<ChatResponse>() + .await + .map_err(|e| OllamaError::Deserialization(e.to_string())) + } + + /// POST /api/embed.
+ /// + /// Generates embedding vectors for the given input texts. + /// Returns one embedding vector per input string. + pub async fn embed( + &self, + model: &str, + input: Vec<String>, + ) -> Result<EmbedResponse, OllamaError> { + let request = EmbedRequest { + model: model.to_string(), + input, + }; + + let resp = self.client + .post(format!("{}/api/embed", self.base_url)) + .json(&request) + .send() + .await?; + + self.handle_error_response(resp) + .await? + .json::<EmbedResponse>() + .await + .map_err(|e| OllamaError::Deserialization(e.to_string())) + } + + /// GET /api/tags. + /// + /// Lists all models available on the Ollama instance. + pub async fn list_models(&self) -> Result<ListModelsResponse, OllamaError> { + let resp = self.client + .get(format!("{}/api/tags", self.base_url)) + .send() + .await?; + + self.handle_error_response(resp) + .await? + .json::<ListModelsResponse>() + .await + .map_err(|e| OllamaError::Deserialization(e.to_string())) + } + + /// Check if Ollama is reachable by hitting GET /api/tags. + /// Returns true if the request succeeds, false otherwise. + pub async fn is_healthy(&self) -> bool { + self.list_models().await.is_ok() + } + + /// Parse NDJSON streaming response into a Stream of typed chunks. + /// + /// Ollama streams responses as newline-delimited JSON. Each line + /// is a complete JSON object. This method uses `bytes_stream()` + /// from reqwest and buffers bytes until a newline is found, + /// then deserializes each complete line. + fn ndjson_stream<T: serde::de::DeserializeOwned>( + resp: reqwest::Response, + ) -> impl Stream<Item = Result<T, OllamaError>> { + use futures::StreamExt; + + let byte_stream = resp.bytes_stream(); + let mut buffer = Vec::new(); + + futures::stream::unfold( + (byte_stream, buffer), + |(mut stream, mut buf)| async move { + // Implementation: accumulate bytes, split on \n, + // deserialize each complete line as T. + // Return None when stream ends. + // ... + }, + ) + } + + /// Check response status and extract error message for non-2xx responses.
+ async fn handle_error_response( + &self, + resp: reqwest::Response, + ) -> Result<reqwest::Response, OllamaError> { + if resp.status().is_success() { + return Ok(resp); + } + + let status = resp.status().as_u16(); + let message = resp + .text() + .await + .unwrap_or_else(|_| "unknown error".to_string()); + + Err(OllamaError::Api { status, message }) + } +} +``` + +**NDJSON stream implementation detail:** + +The `ndjson_stream` method will use `reqwest::Response::bytes_stream()` (requires the `stream` feature) and `futures::stream::unfold` to: +1. Accumulate bytes from the HTTP response body into a buffer. +2. On each newline boundary, extract the complete line. +3. Deserialize the line as `T` using `serde_json::from_slice`. +4. Yield `Ok(T)` or `Err(OllamaError::Deserialization(...))`. +5. Return `None` when the byte stream is exhausted. + +This approach handles partial JSON objects that span multiple TCP chunks correctly. + +### 3. gRPC Handler Wiring + +This issue does **not** implement the gRPC handler wiring — that is deferred to subsequent issues. However, the `OllamaClient` must be integrated into `ModelGatewayServiceImpl` so that future handler implementations can use it. + +**Update `services/model-gateway/src/service.rs`:** + +Add `OllamaClient` as a field on `ModelGatewayServiceImpl`: + +```rust +use crate::ollama::OllamaClient; + +pub struct ModelGatewayServiceImpl { + config: Config, + ollama: OllamaClient, +} + +impl ModelGatewayServiceImpl { + pub fn new(config: Config) -> anyhow::Result<Self> { + let ollama = OllamaClient::new(&config)?; + Ok(Self { config, ollama }) + } +} +``` + +Note: The constructor changes from infallible to `Result` since `reqwest::Client::builder().build()` can fail. Update `main.rs` accordingly to use `?`. + +**Update `services/model-gateway/src/main.rs`:** + +Change `ModelGatewayServiceImpl::new(config)` to `ModelGatewayServiceImpl::new(config)?`. + +### 4. Service Integration + +No cross-service integration is needed for this issue.
The `OllamaClient` is a standalone HTTP client that talks to the local Ollama instance. Integration with gRPC handlers will happen in follow-up issues. + +### 5. Tests + +**Unit tests for serde types in `services/model-gateway/src/ollama/types.rs`:** + +| Test Case | Description | +|---|---| +| `test_generate_request_serialization` | `GenerateRequest` serializes to expected JSON with `stream: false` | +| `test_generate_request_serialization_with_options` | Options fields are included when `Some`, omitted when `None` | +| `test_generate_response_deserialization` | Deserialize a complete Ollama generate response JSON | +| `test_generate_response_missing_optional_fields` | Optional fields default to `None` when absent | +| `test_generate_stream_chunk_deserialization` | Deserialize a streaming chunk (partial token, `done: false`) | +| `test_generate_stream_chunk_final` | Deserialize final chunk with `done: true` and `done_reason` | +| `test_chat_request_serialization` | `ChatRequest` with multiple messages serializes correctly | +| `test_chat_role_serialization` | `ChatRole` variants serialize as lowercase strings | +| `test_chat_response_deserialization` | Deserialize a complete chat response | +| `test_embed_request_serialization` | `EmbedRequest` with multiple inputs serializes correctly | +| `test_embed_response_deserialization` | Deserialize embedding response with vector data | +| `test_list_models_response_deserialization` | Deserialize model listing with multiple models | +| `test_model_info_optional_fields` | `ModelInfo` handles missing `digest` gracefully | + +**Unit tests for error handling in `services/model-gateway/src/ollama/error.rs`:** + +| Test Case | Description | +|---|---| +| `test_error_display_http` | `OllamaError::Http` formats with reqwest message | +| `test_error_display_api` | `OllamaError::Api` includes status code and message | +| `test_error_display_deserialization` | `OllamaError::Deserialization` includes detail | + +**Integration-style 
tests for `OllamaClient` in `services/model-gateway/src/ollama/client.rs`:** + +Use a mock HTTP server (either `mockito` or `wiremock`) to simulate Ollama API responses: + +| Test Case | Description | +|---|---| +| `test_generate_success` | Mock `/api/generate` returns valid JSON, verify parsed response | +| `test_generate_with_options` | Verify temperature, top_p, num_predict, stop are sent in request body | +| `test_generate_stream_success` | Mock returns NDJSON with 3 chunks + final, verify all chunks yielded | +| `test_generate_stream_empty_response` | Mock returns single `done: true` chunk | +| `test_chat_success` | Mock `/api/chat` returns valid response, verify message parsing | +| `test_chat_with_history` | Send multi-message conversation, verify all messages in request body | +| `test_embed_success` | Mock `/api/embed` returns embedding vectors, verify dimensions | +| `test_embed_multiple_inputs` | Send multiple texts, verify multiple embeddings returned | +| `test_list_models_success` | Mock `/api/tags` returns model list | +| `test_list_models_empty` | Mock returns empty model list | +| `test_is_healthy_success` | Mock `/api/tags` returns 200, `is_healthy()` returns true | +| `test_is_healthy_failure` | Mock returns 500, `is_healthy()` returns false | +| `test_api_error_404` | Mock returns 404 with error message, verify `OllamaError::Api` | +| `test_api_error_500` | Mock returns 500 with error body, verify error extraction | +| `test_connection_timeout` | Client configured with very short timeout, verify `OllamaError::Http` | +| `test_base_url_trailing_slash` | Config URL with trailing slash is normalized | + +**Mocking strategy:** + +Use `wiremock` crate as a dev-dependency. It provides a `MockServer` that binds to a random port, allowing parallel test execution without port conflicts. Each test creates its own `MockServer`, configures expected requests/responses, then creates an `OllamaClient` pointed at the mock server URL. 
+ +For streaming tests, the mock server returns a response body containing multiple NDJSON lines separated by `\n`. + +**Configuration tests in `services/model-gateway/src/config.rs`:** + +| Test Case | Description | +|---|---| +| `test_client_config_defaults` | `OllamaClientConfig::default()` returns expected timeout/pool values | +| `test_client_config_from_toml` | Custom client config loads from TOML | +| `test_config_with_client_section` | Full `Config` with `[client]` section parses correctly | + +## Files to Create/Modify + +| File | Action | Purpose | +|---|---|---| +| `services/model-gateway/Cargo.toml` | Modify | Add `reqwest` (with `json`, `stream` features), `futures`, `serde_json` dependencies; add `wiremock` dev-dependency | +| `services/model-gateway/src/config.rs` | Modify | Add `OllamaClientConfig` struct with timeout/pool settings; add `client` field to `Config` | +| `services/model-gateway/src/ollama/mod.rs` | Create | Module declaration, re-exports of `OllamaClient`, `OllamaError`, and types | +| `services/model-gateway/src/ollama/types.rs` | Create | Serde request/response structs for all Ollama API endpoints | +| `services/model-gateway/src/ollama/error.rs` | Create | `OllamaError` enum with `Http`, `Api`, `Deserialization`, `StreamIncomplete` variants | +| `services/model-gateway/src/ollama/client.rs` | Create | `OllamaClient` struct with `generate`, `generate_stream`, `chat`, `embed`, `list_models`, `is_healthy` methods and NDJSON stream parser | +| `services/model-gateway/src/lib.rs` | Modify | Add `pub mod ollama;` | +| `services/model-gateway/src/service.rs` | Modify | Add `OllamaClient` field to `ModelGatewayServiceImpl`; change constructor to return `Result` | +| `services/model-gateway/src/main.rs` | Modify | Update `ModelGatewayServiceImpl::new(config)` call to handle `Result` with `?` | + +## Risks and Edge Cases + +- **Streaming NDJSON parsing:** Ollama sends newline-delimited JSON. 
TCP chunks may not align with JSON object boundaries — a single chunk could contain a partial JSON line or multiple lines. The buffer-based `ndjson_stream` implementation must handle both cases. Mitigation: accumulate bytes until `\n` is found, only parse complete lines. +- **Large model response times:** Inference on large models (14B+) can take minutes. The default request timeout of 300 seconds should be sufficient, but this is configurable. Streaming mitigates perceived latency by yielding tokens incrementally. +- **Ollama API version compatibility:** The `/api/embed` endpoint (with `input` array) was introduced in Ollama 0.1.44+. Older Ollama versions use `/api/embeddings` with a different request shape. Mitigation: target the newer API. Document the minimum Ollama version requirement. +- **Connection pool exhaustion:** If many concurrent gRPC requests hit the gateway simultaneously, reqwest opens a new connection for every in-flight request that cannot reuse an idle one; `pool_max_idle` (passed to `pool_max_idle_per_host`) only caps how many idle connections are kept for reuse and does not limit concurrency. Mitigation: `pool_max_idle` is configurable; the default of 10 is reasonable for a single-node setup. Consider adding a semaphore for concurrency limiting in a future issue if needed. +- **`wiremock` test isolation:** Each test creates its own `MockServer` on a random port, so tests can run in parallel safely. However, `wiremock` adds to dev-dependency compile time. +- **Constructor change breaks existing tests:** Changing `ModelGatewayServiceImpl::new()` from infallible to `Result` will break existing tests in `service.rs`. Mitigation: update the test helper `test_config()` to also construct the `OllamaClient`, or use a test-only constructor that accepts a pre-built client. Alternatively, keep a separate `new_with_client()` constructor for testability and dependency injection. +- **reqwest TLS:** The default reqwest build pulls in `rustls` or `native-tls`. Since Ollama runs locally over plain HTTP, TLS is not needed. Consider using `default-features = false` with just the required features to minimize compile time and binary size.
However, if a user runs Ollama behind a TLS reverse proxy, TLS support is needed. Mitigation: use default features (includes TLS) for now; optimize later if compile time is a concern. + +## Deviation Log + +_(Filled during implementation if deviations from plan occur)_ + +| Deviation | Reason | +|---|---| diff --git a/services/model-gateway/Cargo.toml b/services/model-gateway/Cargo.toml index 243e1e2..674f22d 100644 --- a/services/model-gateway/Cargo.toml +++ b/services/model-gateway/Cargo.toml @@ -18,6 +18,11 @@ toml = "0.8" anyhow = "1" thiserror = "2" tokio-stream = "0.1" +reqwest = { version = "0.12", features = ["json", "stream"] } +futures = "0.3" +serde_json = "1" +bytes = "1" [dev-dependencies] tempfile = "3" +wiremock = "0.6" diff --git a/services/model-gateway/src/config.rs b/services/model-gateway/src/config.rs index e8a638f..fdc0b97 100644 --- a/services/model-gateway/src/config.rs +++ b/services/model-gateway/src/config.rs @@ -71,6 +71,53 @@ impl ModelRoutingConfig { } } +/// Configuration for the Ollama HTTP client. +#[derive(Debug, Clone, Deserialize)] +pub struct OllamaClientConfig { + /// Request timeout in seconds (default: 300). + #[serde(default = "default_request_timeout_secs")] + pub request_timeout_secs: u64, + + /// Connection timeout in seconds (default: 10). + #[serde(default = "default_connect_timeout_secs")] + pub connect_timeout_secs: u64, + + /// Maximum idle connections in the pool (default: 10). + #[serde(default = "default_pool_max_idle")] + pub pool_max_idle: usize, + + /// Idle connection timeout in seconds (default: 60). 
+ #[serde(default = "default_pool_idle_timeout_secs")] + pub pool_idle_timeout_secs: u64, +} + +fn default_request_timeout_secs() -> u64 { + 300 +} + +fn default_connect_timeout_secs() -> u64 { + 10 +} + +fn default_pool_max_idle() -> usize { + 10 +} + +fn default_pool_idle_timeout_secs() -> u64 { + 60 +} + +impl Default for OllamaClientConfig { + fn default() -> Self { + Self { + request_timeout_secs: default_request_timeout_secs(), + connect_timeout_secs: default_connect_timeout_secs(), + pool_max_idle: default_pool_max_idle(), + pool_idle_timeout_secs: default_pool_idle_timeout_secs(), + } + } +} + /// Top-level configuration for the Model Gateway service. #[derive(Debug, Clone, Deserialize)] pub struct Config { @@ -92,6 +139,10 @@ pub struct Config { /// Model routing configuration. #[serde(default)] pub routing: ModelRoutingConfig, + + /// Ollama HTTP client configuration. + #[serde(default)] + pub client: OllamaClientConfig, } fn default_host() -> String { @@ -114,6 +165,7 @@ impl Default for Config { ollama_url: default_ollama_url(), audit_addr: None, routing: ModelRoutingConfig::default(), + client: OllamaClientConfig::default(), } } } @@ -230,6 +282,48 @@ code = "codellama:7b" assert_eq!(count, 1); } + #[test] + fn test_client_config_defaults() { + let cc = OllamaClientConfig::default(); + assert_eq!(cc.request_timeout_secs, 300); + assert_eq!(cc.connect_timeout_secs, 10); + assert_eq!(cc.pool_max_idle, 10); + assert_eq!(cc.pool_idle_timeout_secs, 60); + } + + #[test] + fn test_client_config_from_toml() { + let dir = tempfile::tempdir().unwrap(); + let config_path = dir.path().join("gateway.toml"); + std::fs::write( + &config_path, + r#" +host = "0.0.0.0" +port = 9999 + +[client] +request_timeout_secs = 600 +connect_timeout_secs = 5 +pool_max_idle = 20 +pool_idle_timeout_secs = 120 +"#, + ) + .unwrap(); + + let config = Config::load(Some(config_path.to_str().unwrap())).unwrap(); + assert_eq!(config.client.request_timeout_secs, 600); + 
assert_eq!(config.client.connect_timeout_secs, 5); + assert_eq!(config.client.pool_max_idle, 20); + assert_eq!(config.client.pool_idle_timeout_secs, 120); + } + + #[test] + fn test_client_config_uses_defaults_when_omitted() { + let config = Config::default(); + assert_eq!(config.client.request_timeout_secs, 300); + assert_eq!(config.client.connect_timeout_secs, 10); + } + #[test] fn test_routing_from_toml_uses_defaults_when_omitted() { let dir = tempfile::tempdir().unwrap(); diff --git a/services/model-gateway/src/lib.rs b/services/model-gateway/src/lib.rs index 11bd8fe..81656a2 100644 --- a/services/model-gateway/src/lib.rs +++ b/services/model-gateway/src/lib.rs @@ -1,2 +1,3 @@ pub mod config; +pub mod ollama; pub mod service; diff --git a/services/model-gateway/src/main.rs b/services/model-gateway/src/main.rs index 6ca1a51..aea3acb 100644 --- a/services/model-gateway/src/main.rs +++ b/services/model-gateway/src/main.rs @@ -23,7 +23,7 @@ async fn main() -> anyhow::Result<()> { ); let addr = config.listen_addr().parse()?; - let service = ModelGatewayServiceImpl::new(config); + let service = ModelGatewayServiceImpl::new(config)?; tracing::info!(%addr, "Model Gateway listening"); diff --git a/services/model-gateway/src/ollama/client.rs b/services/model-gateway/src/ollama/client.rs new file mode 100644 index 0000000..d519f9d --- /dev/null +++ b/services/model-gateway/src/ollama/client.rs @@ -0,0 +1,570 @@ +use std::pin::Pin; +use std::time::Duration; + +use bytes::Bytes; +use futures::stream::Stream; +use futures::StreamExt; +use reqwest::Client; + +use super::error::OllamaError; +use super::types::*; +use crate::config::Config; + +/// Async HTTP client for the Ollama REST API. +/// +/// Wraps `reqwest::Client` with connection pooling, timeouts, and +/// typed methods for each Ollama endpoint. +pub struct OllamaClient { + client: Client, + base_url: String, +} + +impl OllamaClient { + /// Create a new client from the service configuration. 
+ pub fn new(config: &Config) -> Result { + let cc = &config.client; + let client = Client::builder() + .timeout(Duration::from_secs(cc.request_timeout_secs)) + .connect_timeout(Duration::from_secs(cc.connect_timeout_secs)) + .pool_max_idle_per_host(cc.pool_max_idle) + .pool_idle_timeout(Duration::from_secs(cc.pool_idle_timeout_secs)) + .build()?; + + let base_url = config.ollama_url.trim_end_matches('/').to_string(); + Ok(Self { client, base_url }) + } + + /// POST /api/generate (non-streaming). + pub async fn generate( + &self, + model: &str, + prompt: &str, + options: Option, + ) -> Result { + let request = GenerateRequest { + model: model.to_string(), + prompt: prompt.to_string(), + stream: false, + options, + }; + + let resp = self + .client + .post(format!("{}/api/generate", self.base_url)) + .json(&request) + .send() + .await?; + + let resp = Self::check_status(resp).await?; + resp.json::() + .await + .map_err(|e| OllamaError::Deserialization(e.to_string())) + } + + /// POST /api/generate (streaming). + /// + /// Returns a `Stream` of `GenerateStreamChunk` items parsed from NDJSON. + pub async fn generate_stream( + &self, + model: &str, + prompt: &str, + options: Option, + ) -> Result< + Pin> + Send>>, + OllamaError, + > { + let request = GenerateRequest { + model: model.to_string(), + prompt: prompt.to_string(), + stream: true, + options, + }; + + let resp = self + .client + .post(format!("{}/api/generate", self.base_url)) + .json(&request) + .send() + .await?; + + let resp = Self::check_status(resp).await?; + Ok(Self::ndjson_stream::(resp)) + } + + /// POST /api/chat (non-streaming). 
+ pub async fn chat( + &self, + model: &str, + messages: Vec, + options: Option, + ) -> Result { + let request = ChatRequest { + model: model.to_string(), + messages, + stream: false, + options, + }; + + let resp = self + .client + .post(format!("{}/api/chat", self.base_url)) + .json(&request) + .send() + .await?; + + let resp = Self::check_status(resp).await?; + resp.json::() + .await + .map_err(|e| OllamaError::Deserialization(e.to_string())) + } + + /// POST /api/embed. + pub async fn embed( + &self, + model: &str, + input: Vec, + ) -> Result { + let request = EmbedRequest { + model: model.to_string(), + input, + }; + + let resp = self + .client + .post(format!("{}/api/embed", self.base_url)) + .json(&request) + .send() + .await?; + + let resp = Self::check_status(resp).await?; + resp.json::() + .await + .map_err(|e| OllamaError::Deserialization(e.to_string())) + } + + /// GET /api/tags — list all models available on the Ollama instance. + pub async fn list_models(&self) -> Result { + let resp = self + .client + .get(format!("{}/api/tags", self.base_url)) + .send() + .await?; + + let resp = Self::check_status(resp).await?; + resp.json::() + .await + .map_err(|e| OllamaError::Deserialization(e.to_string())) + } + + /// Check if Ollama is reachable. + pub async fn is_healthy(&self) -> bool { + self.list_models().await.is_ok() + } + + /// Parse an NDJSON response body into a typed stream. + /// + /// Ollama streams newline-delimited JSON. Each line is a complete + /// JSON object. This buffers bytes until newlines are found, then + /// deserializes each complete line. + fn ndjson_stream( + resp: reqwest::Response, + ) -> Pin> + Send>> { + let byte_stream = resp.bytes_stream(); + + type ByteStream = Pin> + Send>>; + + Box::pin(futures::stream::unfold( + (Box::pin(byte_stream) as ByteStream, Vec::::new()), + |(mut stream, mut buf): (ByteStream, Vec)| async move { + loop { + // Check if buffer contains a complete line. 
+ if let Some(pos) = buf.iter().position(|&b| b == b'\n') { + let line = buf.drain(..=pos).collect::>(); + let line = &line[..line.len() - 1]; // strip newline + if line.is_empty() { + continue; // skip empty lines + } + let result = serde_json::from_slice::(line) + .map_err(|e| OllamaError::Deserialization(e.to_string())); + return Some((result, (stream, buf))); + } + + // Need more data from the stream. + match stream.next().await { + Some(Ok(chunk)) => { + buf.extend_from_slice(&chunk); + } + Some(Err(e)) => { + return Some((Err(OllamaError::Http(e)), (stream, buf))); + } + None => { + // Stream ended. Parse any remaining data in buffer. + if buf.is_empty() { + return None; + } + // Trim trailing whitespace + let trimmed = buf.trim_ascii(); + if trimmed.is_empty() { + return None; + } + let result = serde_json::from_slice::(trimmed) + .map_err(|e| OllamaError::Deserialization(e.to_string())); + buf.clear(); + return Some((result, (stream, buf))); + } + } + } + }, + )) + } + + /// Check response status and extract error message for non-2xx. 
+ async fn check_status(resp: reqwest::Response) -> Result { + if resp.status().is_success() { + return Ok(resp); + } + let status = resp.status().as_u16(); + let message = resp + .text() + .await + .unwrap_or_else(|_| "unknown error".to_string()); + Err(OllamaError::Api { status, message }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + fn test_client(base_url: &str) -> OllamaClient { + let config = Config { + ollama_url: base_url.to_string(), + ..Config::default() + }; + OllamaClient::new(&config).unwrap() + } + + #[tokio::test] + async fn test_generate_success() { + let mock_server = MockServer::start().await; + let body = r#"{"model":"llama3.2:3b","response":"Hello world!","done":true,"done_reason":"stop","eval_count":5}"#; + + Mock::given(method("POST")) + .and(path("/api/generate")) + .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json")) + .mount(&mock_server) + .await; + + let client = test_client(&mock_server.uri()); + let resp = client.generate("llama3.2:3b", "Hi", None).await.unwrap(); + assert_eq!(resp.response, "Hello world!"); + assert!(resp.done); + assert_eq!(resp.eval_count, Some(5)); + } + + #[tokio::test] + async fn test_generate_with_options() { + let mock_server = MockServer::start().await; + let body = r#"{"model":"llama3.2:3b","response":"test","done":true}"#; + + Mock::given(method("POST")) + .and(path("/api/generate")) + .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json")) + .mount(&mock_server) + .await; + + let client = test_client(&mock_server.uri()); + let opts = GenerateOptions { + temperature: Some(0.7), + top_p: Some(0.9), + num_predict: Some(100), + stop: Some(vec!["STOP".to_string()]), + }; + let resp = client + .generate("llama3.2:3b", "Hi", Some(opts)) + .await + .unwrap(); + assert_eq!(resp.response, "test"); + } + + #[tokio::test] + async fn test_generate_stream_success() 
{ + let mock_server = MockServer::start().await; + let ndjson = concat!( + r#"{"model":"llama3.2:3b","response":"Hello","done":false}"#, "\n", + r#"{"model":"llama3.2:3b","response":" world","done":false}"#, "\n", + r#"{"model":"llama3.2:3b","response":"!","done":false}"#, "\n", + r#"{"model":"llama3.2:3b","response":"","done":true,"done_reason":"stop","eval_count":3}"#, "\n", + ); + + Mock::given(method("POST")) + .and(path("/api/generate")) + .respond_with( + ResponseTemplate::new(200).set_body_raw(ndjson, "application/x-ndjson"), + ) + .mount(&mock_server) + .await; + + let client = test_client(&mock_server.uri()); + let stream = client + .generate_stream("llama3.2:3b", "Hi", None) + .await + .unwrap(); + + let chunks: Vec<_> = stream.collect::>().await; + assert_eq!(chunks.len(), 4); + assert_eq!(chunks[0].as_ref().unwrap().response, "Hello"); + assert_eq!(chunks[1].as_ref().unwrap().response, " world"); + assert_eq!(chunks[2].as_ref().unwrap().response, "!"); + assert!(chunks[3].as_ref().unwrap().done); + } + + #[tokio::test] + async fn test_generate_stream_single_done() { + let mock_server = MockServer::start().await; + let ndjson = r#"{"model":"llama3.2:3b","response":"","done":true,"done_reason":"stop"}"#; + + Mock::given(method("POST")) + .and(path("/api/generate")) + .respond_with( + ResponseTemplate::new(200).set_body_raw(ndjson, "application/x-ndjson"), + ) + .mount(&mock_server) + .await; + + let client = test_client(&mock_server.uri()); + let stream = client + .generate_stream("llama3.2:3b", "Hi", None) + .await + .unwrap(); + + let chunks: Vec<_> = stream.collect::>().await; + assert_eq!(chunks.len(), 1); + assert!(chunks[0].as_ref().unwrap().done); + } + + #[tokio::test] + async fn test_chat_success() { + let mock_server = MockServer::start().await; + let body = r#"{"model":"llama3.2:3b","message":{"role":"assistant","content":"Hi there!"},"done":true,"done_reason":"stop"}"#; + + Mock::given(method("POST")) + .and(path("/api/chat")) + 
.respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json")) + .mount(&mock_server) + .await; + + let client = test_client(&mock_server.uri()); + let msgs = vec![ChatMessage { + role: ChatRole::User, + content: "Hello".to_string(), + }]; + let resp = client.chat("llama3.2:3b", msgs, None).await.unwrap(); + assert_eq!(resp.message.content, "Hi there!"); + } + + #[tokio::test] + async fn test_chat_with_history() { + let mock_server = MockServer::start().await; + let body = r#"{"model":"llama3.2:3b","message":{"role":"assistant","content":"Fine, thanks!"},"done":true}"#; + + Mock::given(method("POST")) + .and(path("/api/chat")) + .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json")) + .mount(&mock_server) + .await; + + let client = test_client(&mock_server.uri()); + let msgs = vec![ + ChatMessage { + role: ChatRole::System, + content: "Be helpful".to_string(), + }, + ChatMessage { + role: ChatRole::User, + content: "Hi".to_string(), + }, + ChatMessage { + role: ChatRole::Assistant, + content: "Hello!".to_string(), + }, + ChatMessage { + role: ChatRole::User, + content: "How are you?".to_string(), + }, + ]; + let resp = client.chat("llama3.2:3b", msgs, None).await.unwrap(); + assert_eq!(resp.message.content, "Fine, thanks!"); + } + + #[tokio::test] + async fn test_embed_success() { + let mock_server = MockServer::start().await; + let body = r#"{"model":"nomic-embed-text","embeddings":[[0.1,0.2,0.3]]}"#; + + Mock::given(method("POST")) + .and(path("/api/embed")) + .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json")) + .mount(&mock_server) + .await; + + let client = test_client(&mock_server.uri()); + let resp = client + .embed("nomic-embed-text", vec!["hello".to_string()]) + .await + .unwrap(); + assert_eq!(resp.embeddings.len(), 1); + assert_eq!(resp.embeddings[0].len(), 3); + } + + #[tokio::test] + async fn test_embed_multiple_inputs() { + let mock_server = MockServer::start().await; + let body 
= r#"{"model":"nomic-embed-text","embeddings":[[0.1,0.2],[0.3,0.4]]}"#; + + Mock::given(method("POST")) + .and(path("/api/embed")) + .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json")) + .mount(&mock_server) + .await; + + let client = test_client(&mock_server.uri()); + let resp = client + .embed( + "nomic-embed-text", + vec!["hello".to_string(), "world".to_string()], + ) + .await + .unwrap(); + assert_eq!(resp.embeddings.len(), 2); + } + + #[tokio::test] + async fn test_list_models_success() { + let mock_server = MockServer::start().await; + let body = r#"{"models":[{"name":"llama3.2:3b","model":"llama3.2:3b","size":1234567},{"name":"nomic-embed-text","model":"nomic-embed-text","size":987654}]}"#; + + Mock::given(method("GET")) + .and(path("/api/tags")) + .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json")) + .mount(&mock_server) + .await; + + let client = test_client(&mock_server.uri()); + let resp = client.list_models().await.unwrap(); + assert_eq!(resp.models.len(), 2); + } + + #[tokio::test] + async fn test_list_models_empty() { + let mock_server = MockServer::start().await; + let body = r#"{"models":[]}"#; + + Mock::given(method("GET")) + .and(path("/api/tags")) + .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json")) + .mount(&mock_server) + .await; + + let client = test_client(&mock_server.uri()); + let resp = client.list_models().await.unwrap(); + assert!(resp.models.is_empty()); + } + + #[tokio::test] + async fn test_is_healthy_success() { + let mock_server = MockServer::start().await; + let body = r#"{"models":[]}"#; + + Mock::given(method("GET")) + .and(path("/api/tags")) + .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json")) + .mount(&mock_server) + .await; + + let client = test_client(&mock_server.uri()); + assert!(client.is_healthy().await); + } + + #[tokio::test] + async fn test_is_healthy_failure() { + let mock_server = 
MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/api/tags")) + .respond_with(ResponseTemplate::new(500).set_body_string("internal error")) + .mount(&mock_server) + .await; + + let client = test_client(&mock_server.uri()); + assert!(!client.is_healthy().await); + } + + #[tokio::test] + async fn test_api_error_404() { + let mock_server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/api/generate")) + .respond_with(ResponseTemplate::new(404).set_body_string("model 'nonexistent' not found")) + .mount(&mock_server) + .await; + + let client = test_client(&mock_server.uri()); + let err = client + .generate("nonexistent", "Hi", None) + .await + .unwrap_err(); + match err { + OllamaError::Api { status, message } => { + assert_eq!(status, 404); + assert!(message.contains("not found")); + } + other => panic!("expected Api error, got: {other:?}"), + } + } + + #[tokio::test] + async fn test_api_error_500() { + let mock_server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/api/generate")) + .respond_with(ResponseTemplate::new(500).set_body_string("internal server error")) + .mount(&mock_server) + .await; + + let client = test_client(&mock_server.uri()); + let err = client + .generate("llama3.2:3b", "Hi", None) + .await + .unwrap_err(); + match err { + OllamaError::Api { status, .. 
} => assert_eq!(status, 500), + other => panic!("expected Api error, got: {other:?}"), + } + } + + #[tokio::test] + async fn test_base_url_trailing_slash() { + let mock_server = MockServer::start().await; + let body = r#"{"models":[]}"#; + + Mock::given(method("GET")) + .and(path("/api/tags")) + .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json")) + .mount(&mock_server) + .await; + + // URL with trailing slash + let config = Config { + ollama_url: format!("{}/", mock_server.uri()), + ..Config::default() + }; + let client = OllamaClient::new(&config).unwrap(); + let resp = client.list_models().await.unwrap(); + assert!(resp.models.is_empty()); + } +} diff --git a/services/model-gateway/src/ollama/error.rs b/services/model-gateway/src/ollama/error.rs new file mode 100644 index 0000000..d37bcc9 --- /dev/null +++ b/services/model-gateway/src/ollama/error.rs @@ -0,0 +1,51 @@ +use thiserror::Error; + +/// Errors that can occur when communicating with the Ollama API. +#[derive(Debug, Error)] +pub enum OllamaError { + /// HTTP-level error (connection refused, timeout, DNS, etc.). + #[error("HTTP error: {0}")] + Http(#[from] reqwest::Error), + + /// Ollama returned a non-2xx status code. + #[error("Ollama API error (status {status}): {message}")] + Api { status: u16, message: String }, + + /// Failed to deserialize Ollama JSON response. + #[error("deserialization error: {0}")] + Deserialization(String), + + /// Stream terminated unexpectedly without a done:true chunk. 
+ #[error("stream ended unexpectedly")] + StreamIncomplete, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error_display_api() { + let err = OllamaError::Api { + status: 404, + message: "model not found".to_string(), + }; + let msg = format!("{err}"); + assert!(msg.contains("404")); + assert!(msg.contains("model not found")); + } + + #[test] + fn test_error_display_deserialization() { + let err = OllamaError::Deserialization("unexpected EOF".to_string()); + let msg = format!("{err}"); + assert!(msg.contains("unexpected EOF")); + } + + #[test] + fn test_error_display_stream_incomplete() { + let err = OllamaError::StreamIncomplete; + let msg = format!("{err}"); + assert!(msg.contains("stream ended unexpectedly")); + } +} diff --git a/services/model-gateway/src/ollama/mod.rs b/services/model-gateway/src/ollama/mod.rs new file mode 100644 index 0000000..96eaed4 --- /dev/null +++ b/services/model-gateway/src/ollama/mod.rs @@ -0,0 +1,6 @@ +pub mod client; +pub mod error; +pub mod types; + +pub use client::OllamaClient; +pub use error::OllamaError; diff --git a/services/model-gateway/src/ollama/types.rs b/services/model-gateway/src/ollama/types.rs new file mode 100644 index 0000000..e2bba54 --- /dev/null +++ b/services/model-gateway/src/ollama/types.rs @@ -0,0 +1,301 @@ +use serde::{Deserialize, Serialize}; + +// --- /api/generate --- + +#[derive(Debug, Serialize)] +pub struct GenerateRequest { + pub model: String, + pub prompt: String, + pub stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option, +} + +#[derive(Debug, Serialize)] +pub struct GenerateOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub num_predict: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, +} + +/// Full response from /api/generate with 
stream:false. +#[derive(Debug, Deserialize)] +pub struct GenerateResponse { + pub model: String, + pub response: String, + pub done: bool, + #[serde(default)] + pub done_reason: Option, + #[serde(default)] + pub total_duration: Option, + #[serde(default)] + pub eval_count: Option, + #[serde(default)] + pub prompt_eval_count: Option, +} + +/// Single chunk from /api/generate with stream:true (NDJSON). +#[derive(Debug, Deserialize)] +pub struct GenerateStreamChunk { + pub model: String, + pub response: String, + pub done: bool, + #[serde(default)] + pub done_reason: Option, + #[serde(default)] + pub eval_count: Option, + #[serde(default)] + pub prompt_eval_count: Option, +} + +// --- /api/chat --- + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ChatRole { + System, + User, + Assistant, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatMessage { + pub role: ChatRole, + pub content: String, +} + +#[derive(Debug, Serialize)] +pub struct ChatRequest { + pub model: String, + pub messages: Vec, + pub stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option, +} + +#[derive(Debug, Deserialize)] +pub struct ChatResponse { + pub model: String, + pub message: ChatMessage, + pub done: bool, + #[serde(default)] + pub done_reason: Option, + #[serde(default)] + pub total_duration: Option, + #[serde(default)] + pub eval_count: Option, + #[serde(default)] + pub prompt_eval_count: Option, +} + +// --- /api/embed --- + +#[derive(Debug, Serialize)] +pub struct EmbedRequest { + pub model: String, + pub input: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct EmbedResponse { + pub model: String, + pub embeddings: Vec>, +} + +// --- /api/tags (list models) --- + +#[derive(Debug, Deserialize)] +pub struct ListModelsResponse { + pub models: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct ModelInfo { + pub name: String, + pub model: String, + #[serde(default)] + pub size: u64, + 
#[serde(default)] + pub digest: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_request_serialization() { + let req = GenerateRequest { + model: "llama3.2:3b".to_string(), + prompt: "Hello".to_string(), + stream: false, + options: None, + }; + let json = serde_json::to_value(&req).unwrap(); + assert_eq!(json["model"], "llama3.2:3b"); + assert_eq!(json["prompt"], "Hello"); + assert_eq!(json["stream"], false); + assert!(json.get("options").is_none()); + } + + #[test] + fn test_generate_request_serialization_with_options() { + let req = GenerateRequest { + model: "llama3.2:3b".to_string(), + prompt: "Hello".to_string(), + stream: false, + options: Some(GenerateOptions { + temperature: Some(0.7), + top_p: None, + num_predict: Some(100), + stop: Some(vec!["<|end|>".to_string()]), + }), + }; + let json = serde_json::to_value(&req).unwrap(); + let opts = &json["options"]; + assert!((opts["temperature"].as_f64().unwrap() - 0.7).abs() < 0.001); + assert!(opts.get("top_p").is_none()); + assert_eq!(opts["num_predict"], 100); + assert_eq!(opts["stop"][0], "<|end|>"); + } + + #[test] + fn test_generate_response_deserialization() { + let json = r#"{ + "model": "llama3.2:3b", + "response": "Hello world!", + "done": true, + "done_reason": "stop", + "total_duration": 1234567890, + "eval_count": 42, + "prompt_eval_count": 10 + }"#; + let resp: GenerateResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.model, "llama3.2:3b"); + assert_eq!(resp.response, "Hello world!"); + assert!(resp.done); + assert_eq!(resp.done_reason.as_deref(), Some("stop")); + assert_eq!(resp.eval_count, Some(42)); + } + + #[test] + fn test_generate_response_missing_optional_fields() { + let json = r#"{ + "model": "llama3.2:3b", + "response": "Hello", + "done": true + }"#; + let resp: GenerateResponse = serde_json::from_str(json).unwrap(); + assert!(resp.done_reason.is_none()); + assert!(resp.total_duration.is_none()); + assert!(resp.eval_count.is_none()); + 
} + + #[test] + fn test_generate_stream_chunk_deserialization() { + let json = r#"{"model":"llama3.2:3b","response":"Hello","done":false}"#; + let chunk: GenerateStreamChunk = serde_json::from_str(json).unwrap(); + assert_eq!(chunk.response, "Hello"); + assert!(!chunk.done); + } + + #[test] + fn test_generate_stream_chunk_final() { + let json = r#"{ + "model": "llama3.2:3b", + "response": "", + "done": true, + "done_reason": "stop", + "eval_count": 15 + }"#; + let chunk: GenerateStreamChunk = serde_json::from_str(json).unwrap(); + assert!(chunk.done); + assert_eq!(chunk.done_reason.as_deref(), Some("stop")); + assert_eq!(chunk.eval_count, Some(15)); + } + + #[test] + fn test_chat_request_serialization() { + let req = ChatRequest { + model: "llama3.2:3b".to_string(), + messages: vec![ + ChatMessage { + role: ChatRole::System, + content: "You are helpful.".to_string(), + }, + ChatMessage { + role: ChatRole::User, + content: "Hello".to_string(), + }, + ], + stream: false, + options: None, + }; + let json = serde_json::to_value(&req).unwrap(); + assert_eq!(json["messages"].as_array().unwrap().len(), 2); + assert_eq!(json["messages"][0]["role"], "system"); + assert_eq!(json["messages"][1]["role"], "user"); + } + + #[test] + fn test_chat_role_serialization() { + let json = serde_json::to_value(ChatRole::Assistant).unwrap(); + assert_eq!(json, "assistant"); + } + + #[test] + fn test_chat_response_deserialization() { + let json = r#"{ + "model": "llama3.2:3b", + "message": {"role": "assistant", "content": "Hi there!"}, + "done": true, + "done_reason": "stop", + "eval_count": 5 + }"#; + let resp: ChatResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.message.content, "Hi there!"); + assert!(resp.done); + } + + #[test] + fn test_embed_request_serialization() { + let req = EmbedRequest { + model: "nomic-embed-text".to_string(), + input: vec!["hello".to_string(), "world".to_string()], + }; + let json = serde_json::to_value(&req).unwrap(); + 
assert_eq!(json["model"], "nomic-embed-text"); + assert_eq!(json["input"].as_array().unwrap().len(), 2); + } + + #[test] + fn test_embed_response_deserialization() { + let json = r#"{ + "model": "nomic-embed-text", + "embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + }"#; + let resp: EmbedResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.embeddings.len(), 2); + assert_eq!(resp.embeddings[0].len(), 3); + } + + #[test] + fn test_list_models_response_deserialization() { + let json = r#"{ + "models": [ + {"name": "llama3.2:3b", "model": "llama3.2:3b", "size": 1234567, "digest": "abc123"}, + {"name": "nomic-embed-text", "model": "nomic-embed-text", "size": 987654} + ] + }"#; + let resp: ListModelsResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.models.len(), 2); + assert_eq!(resp.models[0].name, "llama3.2:3b"); + assert_eq!(resp.models[0].digest.as_deref(), Some("abc123")); + assert!(resp.models[1].digest.is_none()); + } +} diff --git a/services/model-gateway/src/service.rs b/services/model-gateway/src/service.rs index bb914f4..b594196 100644 --- a/services/model-gateway/src/service.rs +++ b/services/model-gateway/src/service.rs @@ -6,15 +6,19 @@ use llm_multiverse_proto::llm_multiverse::v1::{ use tonic::{Request, Response, Status}; use crate::config::Config; +use crate::ollama::OllamaClient; /// Implementation of the ModelGatewayService gRPC trait. 
pub struct ModelGatewayServiceImpl { config: Config, + #[allow(dead_code)] + ollama: OllamaClient, } impl ModelGatewayServiceImpl { - pub fn new(config: Config) -> Self { - Self { config } + pub fn new(config: Config) -> Result { + let ollama = OllamaClient::new(&config)?; + Ok(Self { config, ollama }) } } @@ -90,7 +94,7 @@ mod tests { #[tokio::test] async fn test_is_model_ready_all_models() { - let svc = ModelGatewayServiceImpl::new(test_config()); + let svc = ModelGatewayServiceImpl::new(test_config()).unwrap(); let req = Request::new(IsModelReadyRequest { model_name: None }); let resp = svc.is_model_ready(req).await.unwrap().into_inner(); @@ -105,7 +109,7 @@ mod tests { #[tokio::test] async fn test_is_model_ready_specific_model_found() { - let svc = ModelGatewayServiceImpl::new(test_config()); + let svc = ModelGatewayServiceImpl::new(test_config()).unwrap(); let req = Request::new(IsModelReadyRequest { model_name: Some("llama3.2:3b".to_string()), }); @@ -118,7 +122,7 @@ mod tests { #[tokio::test] async fn test_is_model_ready_specific_model_not_found() { - let svc = ModelGatewayServiceImpl::new(test_config()); + let svc = ModelGatewayServiceImpl::new(test_config()).unwrap(); let req = Request::new(IsModelReadyRequest { model_name: Some("nonexistent:latest".to_string()), }); @@ -137,7 +141,7 @@ mod tests { .aliases .insert("code".into(), "codellama:7b".into()); - let svc = ModelGatewayServiceImpl::new(config); + let svc = ModelGatewayServiceImpl::new(config).unwrap(); let req = Request::new(IsModelReadyRequest { model_name: Some("codellama:7b".to_string()), }); @@ -148,7 +152,7 @@ mod tests { #[tokio::test] async fn test_stream_inference_unimplemented() { - let svc = ModelGatewayServiceImpl::new(test_config()); + let svc = ModelGatewayServiceImpl::new(test_config()).unwrap(); let req = Request::new(StreamInferenceRequest { params: None }); let result = svc.stream_inference(req).await; @@ -158,7 +162,7 @@ mod tests { #[tokio::test] async fn 
test_inference_unimplemented() { - let svc = ModelGatewayServiceImpl::new(test_config()); + let svc = ModelGatewayServiceImpl::new(test_config()).unwrap(); let req = Request::new(InferenceRequest { params: None }); let result = svc.inference(req).await; @@ -168,7 +172,7 @@ mod tests { #[tokio::test] async fn test_generate_embedding_unimplemented() { - let svc = ModelGatewayServiceImpl::new(test_config()); + let svc = ModelGatewayServiceImpl::new(test_config()).unwrap(); let req = Request::new(GenerateEmbeddingRequest { context: None, text: "test".into(),