diff --git a/implementation-plans/issue-031.md b/implementation-plans/issue-031.md new file mode 100644 index 0000000..99a52ce --- /dev/null +++ b/implementation-plans/issue-031.md @@ -0,0 +1,445 @@ +# Implementation Plan — Issue #31: Implement extraction step (model call for relevant segment) + +## Metadata + +| Field | Value | +|---|---| +| Issue | [#31](https://git.shahondin1624.de/llm-multiverse/llm-multiverse/issues/31) | +| Title | Implement extraction step (model call for relevant segment) | +| Milestone | Phase 4: Memory Service | +| Labels | | +| Status | `IMPLEMENTING` | +| Language | Rust | +| Related Plans | [issue-011.md](issue-011.md), [issue-012.md](issue-012.md), [issue-029.md](issue-029.md), [issue-030.md](issue-030.md) | +| Blocked by | #30 (completed) | + +## Acceptance Criteria + +- [ ] Post-retrieval model call via Model Gateway `Inference` +- [ ] Prompt extracts relevant segment from memory content given query +- [ ] Output is a focused, concise segment (not full memory) +- [ ] Configurable: can be disabled for low-latency queries +- [ ] Results tagged with extraction confidence + +## Architecture Analysis + +### Service Context + +This issue belongs to the **Memory Service** (Rust). It implements the **extraction step** described in the architecture document: + +> When a corpus is read in full, a lightweight model call extracts only the segment relevant to the query. This extracted segment — not the full corpus — is what enters agent context and what is cached. + +The extraction step runs **after** the 4-stage retrieval pipeline (issue #30) completes. For each candidate that survived all stages, a unary `Inference` call is made to the Model Gateway with a prompt instructing the model to extract only the relevant segment from the corpus, given the original query. This is a lightweight model call routed via `TaskComplexity::SIMPLE` (3B/7B model). 
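The prompt construction for this lightweight call can be sketched as a plain function. This is a minimal, self-contained sketch that mirrors the `build_extraction_prompt()` declared in the Implementation Steps and the prompt template defined there; the exact wording is illustrative, not final, and the gRPC call itself is covered in step 2.

```rust
/// Sketch of the extraction prompt builder (mirrors the template in step 2 of
/// the Implementation Steps; wording is illustrative, not final).
fn build_extraction_prompt(query: &str, corpus: &str, memory_name: &str) -> String {
    format!(
        "Given the following search query and memory content, extract ONLY the segment\n\
         of the content that is most relevant to the query. Return a concise, focused\n\
         excerpt, not the full content.\n\n\
         Query: {query}\n\
         Memory name: {memory_name}\n\n\
         Content:\n---\n{corpus}\n---\n\n\
         Respond in this exact format:\n\
         SEGMENT: <extracted text>\n\
         CONFIDENCE: <0.0-1.0 score indicating relevance>"
    )
}
```

Keeping the query and instructions before the delimited corpus matters for the large-corpus edge case discussed under Risks: if the gateway truncates the prompt, the instructions survive.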
+ +**Affected gRPC endpoints:** +- `QueryMemory` (server-streaming) -- the extraction step is integrated into the response pipeline between retrieval and streaming results back to the client. + +**Proto messages used:** +- `InferenceRequest` / `InferenceResponse` (from `model_gateway.proto`) -- the Memory Service becomes a gRPC client of the Model Gateway's `Inference` RPC (in addition to the existing `GenerateEmbedding` client). +- `InferenceParams` with `TaskComplexity::SIMPLE` -- extraction is a lightweight task. +- `QueryMemoryRequest` -- needs a mechanism to signal whether extraction should be enabled or disabled (for low-latency queries). The proto already has `memory_type` and `limit` fields but no extraction toggle. Options: (a) add an `optional bool skip_extraction` field to the proto, or (b) use a config-only toggle. Given the acceptance criterion "configurable: can be disabled for low-latency queries", a per-request toggle in `QueryMemoryRequest` is needed. This requires a proto change. +- `QueryMemoryResponse` -- the `cached_extracted_segment` field (proto line 62) already exists for extracted segments. Currently always `None`. The extraction result populates this field. + +### Existing Patterns + +- **Model Gateway gRPC client:** The Memory Service already has an `EmbeddingClient` that wraps `ModelGatewayServiceClient` (see `services/memory/src/embedding/mod.rs:116-118`). The `Inference` RPC is on the same `ModelGatewayServiceClient`, so the existing channel can be reused. However, the current `EmbeddingClient` only exposes `GenerateEmbedding`. The extraction client needs access to `Inference`. +- **Builder pattern for clients:** `MemoryServiceImpl::with_embedding_client()` at `services/memory/src/service.rs:43-46` follows the builder pattern from the secrets service. The extraction client can follow the same pattern, or the existing `EmbeddingClient` can be extended to also provide an inference method. 
+- **Config pattern:** `RetrievalConfig` in `services/memory/src/config.rs:5-18` uses `#[serde(default)]` with named default functions. The extraction config should follow the same pattern. +- **Pipeline execution:** `retrieval::pipeline::execute_pipeline()` at `services/memory/src/retrieval/pipeline.rs:32-118` is synchronous (runs inside `with_connection()` closures). The extraction step requires async gRPC calls, so it must run **after** the pipeline, at the service layer in `service.rs` (similar to how embedding generation runs before the pipeline). +- **Streaming response:** Results are streamed via `tokio::sync::mpsc::channel` at `services/memory/src/service.rs:145-161`. Extraction can run per-candidate before sending each response, or all extractions can run first, then stream. + +### Dependencies + +- **Model Gateway `Inference` RPC** -- the gateway service is not yet implemented, but the proto stubs are generated. The extraction client will call `ModelGatewayServiceClient::inference()`. +- **Proto change** -- `QueryMemoryRequest` needs a `skip_extraction` field for per-request toggle. +- **No new crate dependencies** -- `tonic`, `prost`, and the proto stubs already provide everything needed. + +## Implementation Steps + +### 1. Types & Configuration + +**Update `proto/llm_multiverse/v1/memory.proto` -- Add extraction toggle to `QueryMemoryRequest`:** + +```protobuf +message QueryMemoryRequest { + SessionContext context = 1; + string query = 2; + string memory_type = 3; + uint32 limit = 4; + // When true, skip the extraction step for lower latency. + bool skip_extraction = 5; +} +``` + +**Update `proto/llm_multiverse/v1/memory.proto` -- Add extraction confidence to `QueryMemoryResponse`:** + +```protobuf +message QueryMemoryResponse { + uint32 rank = 1; + MemoryEntry entry = 2; + float cosine_similarity = 3; + bool is_cached = 4; + optional string cached_extracted_segment = 5; + // Confidence of the extraction (0.0-1.0). 
Only set when extraction was performed. + optional float extraction_confidence = 6; +} +``` + +After proto changes, regenerate Rust stubs by running the build. + +**Add extraction configuration to `services/memory/src/config.rs`:** + +```rust +/// Configuration for the post-retrieval extraction step. +#[derive(Debug, Clone, Deserialize)] +pub struct ExtractionConfig { + /// Whether extraction is enabled globally (default: true). + /// When false, extraction is never performed regardless of per-request settings. + #[serde(default = "default_extraction_enabled")] + pub enabled: bool, + + /// Maximum tokens for the extraction model response (default: 256). + #[serde(default = "default_extraction_max_tokens")] + pub max_tokens: u32, + + /// Temperature for extraction inference (default: 0.1, low for deterministic extraction). + #[serde(default = "default_extraction_temperature")] + pub temperature: f32, +} +``` + +Add `extraction: ExtractionConfig` field to the `Config` struct with `#[serde(default)]`. + +**Define extraction types in a new `services/memory/src/extraction/mod.rs`:** + +```rust +/// Result of extracting a relevant segment from a memory corpus. +#[derive(Debug, Clone)] +pub struct ExtractionResult { + /// The memory ID this extraction belongs to. + pub memory_id: String, + /// The extracted relevant segment. + pub segment: String, + /// Confidence score (0.0-1.0) parsed from the model's response. + pub confidence: f32, +} + +/// Errors specific to the extraction step. +#[derive(Debug, thiserror::Error)] +pub enum ExtractionError { + /// The Model Gateway Inference call failed. + #[error("inference call failed: {0}")] + InferenceFailed(#[from] tonic::Status), + + /// Connection to Model Gateway failed. + #[error("connection error: {0}")] + ConnectionError(#[from] tonic::transport::Error), + + /// Failed to parse the model's extraction response. + #[error("failed to parse extraction response: {0}")] + ParseError(String), +} +``` + +### 2. 
Core Logic + +**Create `services/memory/src/extraction/mod.rs` -- Extraction client:** + +```rust +use llm_multiverse_proto::llm_multiverse::v1::{ + model_gateway_service_client::ModelGatewayServiceClient, + InferenceRequest, InferenceParams, SessionContext, TaskComplexity, +}; +use tonic::transport::Channel; + +/// Client for performing post-retrieval extraction via the Model Gateway. +pub struct ExtractionClient { + client: ModelGatewayServiceClient, + config: ExtractionConfig, +} + +impl ExtractionClient { + /// Connect to the Model Gateway at the given endpoint. + pub async fn connect( + endpoint: &str, + config: ExtractionConfig, + ) -> Result; + + /// Extract the relevant segment from a memory corpus given a query. + /// + /// Calls the Model Gateway `Inference` RPC with a structured prompt + /// instructing the model to: + /// 1. Read the corpus text + /// 2. Identify the segment most relevant to the query + /// 3. Return ONLY that segment, verbatim or lightly paraphrased + /// 4. Include a confidence score (0.0-1.0) + /// + /// Uses `TaskComplexity::SIMPLE` for lightweight model routing. + pub async fn extract( + &self, + context: &SessionContext, + query: &str, + corpus: &str, + memory_name: &str, + ) -> Result; + + /// Extract segments for multiple candidates. + /// + /// Processes candidates sequentially to avoid overloading the gateway. + /// On failure for a single candidate, that candidate's extraction is + /// skipped (returns the full corpus as fallback) and processing continues. + pub async fn extract_batch( + &self, + context: &SessionContext, + query: &str, + candidates: &[RetrievalCandidate], + ) -> Vec; +} +``` + +**Extraction prompt template (in `services/memory/src/extraction/prompt.rs`):** + +```rust +/// Build the extraction prompt for a single memory entry. +/// +/// The prompt instructs the model to extract only the relevant segment +/// from the corpus, given the query context. Output format is structured +/// to enable reliable parsing. 
+pub fn build_extraction_prompt( + query: &str, + corpus: &str, + memory_name: &str, +) -> String; +``` + +The prompt should follow this structure: +``` +Given the following search query and memory content, extract ONLY the segment +of the content that is most relevant to the query. Return a concise, focused +excerpt — not the full content. + +Query: {query} +Memory name: {memory_name} + +Content: +--- +{corpus} +--- + +Respond in this exact format: +SEGMENT: +CONFIDENCE: <0.0-1.0 score indicating relevance> +``` + +**Response parser (in `services/memory/src/extraction/prompt.rs`):** + +```rust +/// Parse the model's extraction response into segment and confidence. +/// +/// Expects the format: +/// SEGMENT: +/// CONFIDENCE: +/// +/// Falls back gracefully: +/// - If CONFIDENCE is missing or unparseable, defaults to 0.5 +/// - If SEGMENT is missing, returns the full response as the segment +pub fn parse_extraction_response( + response: &str, + memory_id: &str, +) -> ExtractionResult; +``` + +**Key implementation details for `extract()`:** + +1. Build the prompt via `build_extraction_prompt()`. +2. Construct `InferenceRequest` with: + - `context`: the session context + - `prompt`: the extraction prompt + - `task_complexity`: `TaskComplexity::SIMPLE` (lightweight model) + - `max_tokens`: from `ExtractionConfig::max_tokens` + - `temperature`: from `ExtractionConfig::temperature` +3. Call `self.client.clone().inference(request).await`. +4. Parse the response text via `parse_extraction_response()`. +5. Return `ExtractionResult`. + +**Key implementation details for `extract_batch()`:** + +1. Iterate over candidates sequentially. +2. For each candidate, call `self.extract()`. +3. On error, log a warning and produce a fallback result with the full corpus and confidence 0.0. +4. Return all results (one per candidate). + +### 3. 
gRPC Handler Wiring + +**Update `services/memory/src/service.rs` -- Add extraction client and integrate into `query_memory`:** + +```rust +pub struct MemoryServiceImpl { + db: Arc, + embedding_client: Option>>, + extraction_client: Option>>, + retrieval_config: RetrievalConfig, + extraction_config: ExtractionConfig, +} + +impl MemoryServiceImpl { + pub fn new( + db: Arc, + retrieval_config: RetrievalConfig, + extraction_config: ExtractionConfig, + ) -> Self; + + pub fn with_embedding_client(mut self, client: EmbeddingClient) -> Self; + + /// Attach an extraction client for post-retrieval segment extraction. + pub fn with_extraction_client(mut self, client: ExtractionClient) -> Self; +} +``` + +**Update `query_memory` handler logic (in `services/memory/src/service.rs`):** + +After the retrieval pipeline returns candidates: + +1. Check if extraction should run: + - `extraction_config.enabled` is true AND + - `req.skip_extraction` is false AND + - `extraction_client` is `Some` +2. If extraction runs: + - Lock extraction client + - Call `extract_batch()` with query and candidates + - Drop the lock + - Build `QueryMemoryResponse` with `cached_extracted_segment = Some(result.segment)` and `extraction_confidence = Some(result.confidence)` +3. If extraction is skipped: + - Build `QueryMemoryResponse` with `cached_extracted_segment = None` and `extraction_confidence = None` +4. Stream responses as before. + +### 4. 
Service Integration + +**Update `services/memory/src/main.rs` -- Connect extraction client at startup:** + +```rust +// After embedding client setup: +let extraction_config = config.extraction.clone(); +if extraction_config.enabled { + if let Some(ref endpoint) = config.embedding_endpoint { + // Reuse the same Model Gateway endpoint for both embedding and inference + match ExtractionClient::connect(endpoint, extraction_config.clone()).await { + Ok(client) => { + tracing::info!(endpoint = %endpoint, "Connected to Model Gateway for extraction"); + memory_service = memory_service.with_extraction_client(client); + } + Err(e) => { + tracing::warn!( + endpoint = %endpoint, + error = %e, + "Model Gateway unavailable for extraction — extraction disabled" + ); + } + } + } +} +``` + +The extraction client connects to the **same** Model Gateway endpoint as the embedding client. Both use `ModelGatewayServiceClient` but call different RPCs (`GenerateEmbedding` vs `Inference`). + +**Error mapping:** `ExtractionError` variants map to gRPC status codes: +- `ExtractionError::InferenceFailed(_)` -> logged as warning, extraction skipped (fallback to no extraction) +- `ExtractionError::ConnectionError(_)` -> logged as warning, extraction skipped +- `ExtractionError::ParseError(_)` -> logged as warning, fallback to full corpus + +Extraction errors are **non-fatal** — if extraction fails for a candidate, the response is returned without an extracted segment. This ensures the retrieval pipeline is not blocked by extraction failures. + +### 5. Tests + +**Unit tests in `services/memory/src/extraction/prompt.rs`:** + +| Test Case | Description | +|---|---| +| `test_build_extraction_prompt_contains_query` | Prompt contains the query text | +| `test_build_extraction_prompt_contains_corpus` | Prompt contains the corpus text | +| `test_build_extraction_prompt_contains_memory_name` | Prompt contains the memory name | +| `test_parse_extraction_response_valid` | Correctly parses `SEGMENT: ... 
CONFIDENCE: 0.8` format | +| `test_parse_extraction_response_missing_confidence` | Missing confidence defaults to 0.5 | +| `test_parse_extraction_response_invalid_confidence` | Unparseable confidence defaults to 0.5 | +| `test_parse_extraction_response_missing_segment` | Missing SEGMENT prefix returns full response as segment | +| `test_parse_extraction_response_multiline_segment` | Multi-line segment text is captured correctly | +| `test_parse_extraction_response_empty` | Empty response returns empty segment with confidence 0.0 | + +**Unit tests in `services/memory/src/extraction/mod.rs`:** + +| Test Case | Description | +|---|---| +| `test_extraction_config_defaults` | Default config has `enabled=true`, `max_tokens=256`, `temperature=0.1` | +| `test_extraction_error_display` | Error messages are human-readable | +| `test_extraction_result_fields` | `ExtractionResult` carries memory_id, segment, and confidence | + +**Integration tests using mock Model Gateway (in `services/memory/src/extraction/mod.rs`):** + +| Test Case | Description | +|---|---| +| `test_extract_single_success` | Mock gateway returns extraction text, verify parsed result | +| `test_extract_single_gateway_error` | Mock gateway returns error status, verify `ExtractionError::InferenceFailed` | +| `test_extract_batch_all_success` | All candidates extracted successfully | +| `test_extract_batch_partial_failure` | One candidate fails, others succeed, failed one gets fallback | + +**Service-level tests in `services/memory/src/service.rs`:** + +| Test Case | Description | +|---|---| +| `test_query_memory_with_extraction` | Full flow: DB populated, mock gateway for embedding + inference, extraction runs and populates `cached_extracted_segment` | +| `test_query_memory_skip_extraction` | `skip_extraction=true` in request, verify `cached_extracted_segment` is None | +| `test_query_memory_extraction_disabled_config` | `extraction.enabled=false` in config, verify extraction is skipped | +| 
`test_query_memory_no_extraction_client` | No extraction client configured, verify results returned without extraction | + +**Mocking strategy:** +- Extend the mock Model Gateway server (already used in `services/memory/src/embedding/mod.rs:293-387` and `services/memory/src/service.rs:416-464`) to also handle `Inference` RPC. The mock `inference()` method returns a canned response in `SEGMENT: ... CONFIDENCE: ...` format. +- Use `DuckDbManager::in_memory()` for all DB operations. +- Use `MockEmbeddingGenerator` for embedding generation in pipeline tests. + +### Cargo Dependencies + +No new crate dependencies required. All functionality is available via: +- `tonic` (gRPC client for `Inference` RPC) +- `prost` / `llm-multiverse-proto` (proto stubs including `InferenceRequest`, `InferenceResponse`) +- `thiserror` (error types) + +### Trait Implementations + +No new trait implementations required. The `ExtractionClient` is a concrete struct (not trait-based) since it is only used directly by the service layer. If mocking is needed for tests, the mock Model Gateway server approach (already established) is used instead of a trait. 
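The fallback rules for `parse_extraction_response()` (missing `CONFIDENCE:` defaults to 0.5, a missing `SEGMENT:` prefix falls back to the whole response, an empty response yields confidence 0.0, and the self-reported score is clamped into [0.0, 1.0]) can be sketched in plain Rust. This is a non-authoritative sketch derived from the behavior described in step 2 and the unit-test table above; `ExtractionResult` is redefined locally so the example is self-contained.

```rust
/// Local stand-in for the plan's `ExtractionResult`, so this sketch compiles on its own.
#[derive(Debug, Clone)]
struct ExtractionResult {
    memory_id: String,
    segment: String,
    confidence: f32,
}

/// Sketch of the fallback-tolerant parser: missing CONFIDENCE defaults to 0.5,
/// a missing SEGMENT prefix falls back to the full response, an empty response
/// yields confidence 0.0, and the score is clamped into [0.0, 1.0].
fn parse_extraction_response(response: &str, memory_id: &str) -> ExtractionResult {
    if response.trim().is_empty() {
        return ExtractionResult {
            memory_id: memory_id.to_string(),
            segment: String::new(),
            confidence: 0.0,
        };
    }

    let mut segment: Option<String> = None;
    let mut confidence: Option<f32> = None;
    for line in response.lines() {
        if let Some(rest) = line.strip_prefix("CONFIDENCE:") {
            confidence = rest.trim().parse::<f32>().ok();
        } else if let Some(rest) = line.strip_prefix("SEGMENT:") {
            segment = Some(rest.trim().to_string());
        } else if let (Some(seg), None) = (segment.as_mut(), confidence) {
            // Multi-line segment: keep appending until CONFIDENCE appears.
            seg.push('\n');
            seg.push_str(line);
        }
    }

    let segment = segment.unwrap_or_else(|| response.trim().to_string());
    ExtractionResult {
        memory_id: memory_id.to_string(),
        segment,
        confidence: confidence.unwrap_or(0.5).clamp(0.0, 1.0),
    }
}
```

Because every branch returns a usable `ExtractionResult`, a malformed model response degrades gracefully instead of surfacing a `ParseError`, which is what keeps extraction non-fatal for the pipeline.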
+ +### Error Types + +- `ExtractionError` -- enum covering inference, connection, and parse errors (see Types section above) + +## Files to Create/Modify + +| File | Action | Purpose | +|---|---|---| +| `proto/llm_multiverse/v1/memory.proto` | Modify | Add `skip_extraction` field to `QueryMemoryRequest`; add `extraction_confidence` field to `QueryMemoryResponse` | +| `services/memory/src/lib.rs` | Modify | Add `pub mod extraction;` | +| `services/memory/src/extraction/mod.rs` | Create | `ExtractionClient`, `ExtractionResult`, `ExtractionError`, `ExtractionConfig`, integration tests | +| `services/memory/src/extraction/prompt.rs` | Create | `build_extraction_prompt()`, `parse_extraction_response()`, prompt template, unit tests | +| `services/memory/src/config.rs` | Modify | Add `ExtractionConfig` struct with `enabled`, `max_tokens`, `temperature`; add `extraction` field to `Config` | +| `services/memory/src/service.rs` | Modify | Add `extraction_client` and `extraction_config` to `MemoryServiceImpl`; add `with_extraction_client()` builder; integrate extraction into `query_memory` handler | +| `services/memory/src/main.rs` | Modify | Connect `ExtractionClient` at startup if configured; pass `ExtractionConfig` to `MemoryServiceImpl` | + +## Risks and Edge Cases + +- **Model Gateway not yet implemented:** The Model Gateway `Inference` RPC does not exist yet (only the proto is defined). The extraction client can only be tested with a mock gateway server. The memory service must start cleanly without it (extraction disabled gracefully). +- **Prompt injection via corpus content:** Memory corpus text could contain adversarial content that tries to hijack the extraction prompt. Mitigation: the extraction prompt clearly delineates the corpus with delimiters (`---`), and the model is instructed to only extract, not follow instructions within the corpus. 
The architecture document notes that external-sourced memories are tagged with `MEMORY_PROVENANCE_EXTERNAL`, so the extraction prompt could include a warning for external content. However, since the extraction model is lightweight (3B/7B), it may be more susceptible to injection. Consider adding a `[UNTRUSTED CONTENT]` marker around external-provenance corpora in the prompt. +- **Model response format parsing:** LLMs may not always follow the exact `SEGMENT: ... CONFIDENCE: ...` format. The parser must be robust: fall back to treating the entire response as the segment if the format is not matched, and default confidence to 0.5. This ensures extraction never blocks the pipeline. +- **Latency impact:** Each extraction call adds inference latency (potentially 100-500ms per candidate for a 3B model). For 5 candidates, this could add 0.5-2.5 seconds. Mitigation: (a) the `skip_extraction` toggle allows low-latency queries, (b) extraction calls could be parallelized with `tokio::join!` or `futures::future::join_all` since they are independent. Start with sequential and optimize if needed. +- **Empty corpus:** If a candidate has an empty corpus, extraction is meaningless. Skip extraction for such candidates and return `None` for `cached_extracted_segment`. +- **Very large corpus:** If the corpus exceeds the model's context window, the extraction prompt will be truncated by the gateway. The extraction prompt should place the query and instructions at the beginning (before the corpus) to ensure they are not lost. The `max_tokens` limit on the response prevents excessively long extractions. +- **Extraction confidence calibration:** The confidence score is self-reported by the model and may not be well-calibrated. Consumers of the confidence score should treat it as a rough signal, not a precise metric. Consider clamping to [0.0, 1.0] range in the parser. 
+- **Concurrency with embedding client:** Both the extraction client and embedding client connect to the same Model Gateway. Each uses its own `ModelGatewayServiceClient` instance. `tonic::transport::Channel` internally manages connection pooling, so concurrent calls are handled correctly without additional synchronization beyond the per-client `Mutex`. +- **Proto backward compatibility:** Adding `skip_extraction` (field 5) and `extraction_confidence` (field 6) to existing messages is backward-compatible in protobuf (new fields have default values). Existing clients that don't set `skip_extraction` will get `false` (extraction enabled by default), which is the desired behavior. + +## Deviation Log + +_(Filled during implementation if deviations from plan occur)_ + +| Deviation | Reason | +|---|---| +| `ExtractionResult` placed in `extraction/prompt.rs` instead of `extraction/mod.rs` | Collocating the result struct with the prompt builder and parser that produce it improves cohesion | +| `extract()` takes an additional `memory_id` parameter not in the plan signature | Needed to populate `ExtractionResult.memory_id` without requiring the caller to do it | diff --git a/services/memory/src/config.rs b/services/memory/src/config.rs index 08ec3a7..cde628d 100644 --- a/services/memory/src/config.rs +++ b/services/memory/src/config.rs @@ -29,6 +29,45 @@ fn default_relevance_threshold() -> f32 { 0.3 } +/// Configuration for the post-retrieval extraction step. +#[derive(Debug, Clone, Deserialize)] +pub struct ExtractionConfig { + /// Whether extraction is enabled globally (default: true). + /// When false, extraction is never performed regardless of per-request settings. + #[serde(default = "default_extraction_enabled")] + pub enabled: bool, + + /// Maximum tokens for the extraction model response (default: 256). 
+ #[serde(default = "default_extraction_max_tokens")] + pub max_tokens: u32, + + /// Temperature for extraction inference (default: 0.1, low for deterministic extraction). + #[serde(default = "default_extraction_temperature")] + pub temperature: f32, +} + +fn default_extraction_enabled() -> bool { + true +} + +fn default_extraction_max_tokens() -> u32 { + 256 +} + +fn default_extraction_temperature() -> f32 { + 0.1 +} + +impl Default for ExtractionConfig { + fn default() -> Self { + Self { + enabled: default_extraction_enabled(), + max_tokens: default_extraction_max_tokens(), + temperature: default_extraction_temperature(), + } + } +} + impl Default for RetrievalConfig { fn default() -> Self { Self { @@ -89,6 +128,9 @@ pub struct Config { /// Configuration for provenance tagging and poisoning protection. #[serde(default)] pub provenance: ProvenanceConfig, + /// Configuration for the post-retrieval extraction step. + #[serde(default)] + pub extraction: ExtractionConfig, } fn default_host() -> String { @@ -113,6 +155,7 @@ impl Default for Config { audit_addr: None, retrieval: RetrievalConfig::default(), provenance: ProvenanceConfig::default(), + extraction: ExtractionConfig::default(), } } } @@ -240,6 +283,14 @@ relevance_threshold = 0.5 assert!(pc.sanitization_enabled); } + #[test] + fn test_extraction_config_defaults() { + let ec = ExtractionConfig::default(); + assert!(ec.enabled); + assert_eq!(ec.max_tokens, 256); + assert!((ec.temperature - 0.1).abs() < f32::EPSILON); + } + #[test] fn test_provenance_config_from_toml() { let dir = tempfile::tempdir().unwrap(); @@ -263,6 +314,31 @@ sanitization_enabled = false assert!(!config.provenance.sanitization_enabled); } + #[test] + fn test_extraction_config_from_toml() { + let dir = tempfile::tempdir().unwrap(); + let config_path = dir.path().join("memory.toml"); + std::fs::write( + &config_path, + r#" +host = "0.0.0.0" +port = 9999 +db_path = "/var/lib/memory.duckdb" + +[extraction] +enabled = false +max_tokens = 512 
+temperature = 0.3 +"#, + ) + .unwrap(); + + let config = Config::load(Some(config_path.to_str().unwrap())).unwrap(); + assert!(!config.extraction.enabled); + assert_eq!(config.extraction.max_tokens, 512); + assert!((config.extraction.temperature - 0.3).abs() < f32::EPSILON); + } + #[test] fn test_provenance_config_uses_defaults_when_omitted() { let config = Config::default(); @@ -270,6 +346,14 @@ sanitization_enabled = false assert!(config.provenance.sanitization_enabled); } + #[test] + fn test_extraction_config_uses_defaults_when_omitted() { + let config = Config::default(); + assert!(config.extraction.enabled); + assert_eq!(config.extraction.max_tokens, 256); + assert!((config.extraction.temperature - 0.1).abs() < f32::EPSILON); + } + #[test] fn test_retrieval_config_uses_defaults_when_omitted() { let dir = tempfile::tempdir().unwrap(); diff --git a/services/memory/src/embedding/mod.rs b/services/memory/src/embedding/mod.rs index a24a2ae..cad7308 100644 --- a/services/memory/src/embedding/mod.rs +++ b/services/memory/src/embedding/mod.rs @@ -249,6 +249,12 @@ pub mod mock { pub should_fail: AtomicBool, } + impl Default for MockEmbeddingGenerator { + fn default() -> Self { + Self::new() + } + } + impl MockEmbeddingGenerator { /// Create a mock generator that succeeds. pub fn new() -> Self { @@ -780,7 +786,7 @@ mod tests { ..Default::default() }; - let requests = vec![ + let requests = [ EmbeddingRequest { memory_id: "mem-1".to_string(), field: EmbeddingField::Name, diff --git a/services/memory/src/extraction/mod.rs b/services/memory/src/extraction/mod.rs new file mode 100644 index 0000000..1ea20e1 --- /dev/null +++ b/services/memory/src/extraction/mod.rs @@ -0,0 +1,500 @@ +//! Post-retrieval extraction step for the Memory Service. +//! +//! After the 4-stage retrieval pipeline returns candidates, the extraction step +//! makes a lightweight model call via the Model Gateway `Inference` RPC to extract +//! 
only the segment of each candidate's corpus that is relevant to the query. +//! This extracted segment -- not the full corpus -- enters agent context. + +pub mod prompt; + +use crate::config::ExtractionConfig; +use crate::retrieval::RetrievalCandidate; +use llm_multiverse_proto::llm_multiverse::v1::{ + model_gateway_service_client::ModelGatewayServiceClient, InferenceParams, InferenceRequest, + SessionContext, TaskComplexity, +}; +use prompt::ExtractionResult; +use tonic::transport::Channel; + +/// Errors specific to the extraction step. +#[derive(Debug, thiserror::Error)] +pub enum ExtractionError { + /// The Model Gateway Inference call failed. + #[error("inference call failed: {0}")] + InferenceFailed(#[from] tonic::Status), + + /// Connection to Model Gateway failed. + #[error("connection error: {0}")] + ConnectionError(#[from] tonic::transport::Error), + + /// Failed to parse the model's extraction response. + #[error("failed to parse extraction response: {0}")] + ParseError(String), +} + +/// Client for performing post-retrieval extraction via the Model Gateway. +/// +/// Wraps `ModelGatewayServiceClient` and calls the `Inference` RPC with a +/// structured prompt instructing the model to extract the relevant segment +/// from a memory corpus given a query. +pub struct ExtractionClient { + client: ModelGatewayServiceClient, + config: ExtractionConfig, +} + +impl ExtractionClient { + /// Connect to the Model Gateway at the given endpoint. + /// + /// # Errors + /// + /// Returns [`ExtractionError::ConnectionError`] if the gRPC channel + /// cannot be established. + pub async fn connect( + endpoint: &str, + config: ExtractionConfig, + ) -> Result { + let client = ModelGatewayServiceClient::connect(endpoint.to_string()).await?; + Ok(Self { client, config }) + } + + /// Extract the relevant segment from a memory corpus given a query. + /// + /// Calls the Model Gateway `Inference` RPC with a structured prompt + /// instructing the model to: + /// 1. 
Read the corpus text + /// 2. Identify the segment most relevant to the query + /// 3. Return ONLY that segment, verbatim or lightly paraphrased + /// 4. Include a confidence score (0.0-1.0) + /// + /// Uses `TaskComplexity::Simple` for lightweight model routing. + pub async fn extract( + &self, + context: &SessionContext, + query: &str, + corpus: &str, + memory_name: &str, + memory_id: &str, + ) -> Result { + // Skip extraction for empty corpus + if corpus.is_empty() { + return Ok(ExtractionResult { + memory_id: memory_id.to_string(), + segment: String::new(), + confidence: 0.0, + }); + } + + let extraction_prompt = prompt::build_extraction_prompt(query, corpus, memory_name); + + let request = InferenceRequest { + params: Some(InferenceParams { + context: Some(context.clone()), + prompt: extraction_prompt, + task_complexity: TaskComplexity::Simple.into(), + max_tokens: self.config.max_tokens, + temperature: Some(self.config.temperature), + top_p: None, + stop_sequences: vec![], + }), + }; + + let mut client = self.client.clone(); + let response = client.inference(request).await?.into_inner(); + + Ok(prompt::parse_extraction_response(&response.text, memory_id)) + } + + /// Extract segments for multiple candidates. + /// + /// Processes candidates sequentially to avoid overloading the gateway. + /// On failure for a single candidate, that candidate's extraction is + /// skipped (returns the full corpus as fallback) and processing continues. 
+ pub async fn extract_batch( + &self, + context: &SessionContext, + query: &str, + candidates: &[RetrievalCandidate], + ) -> Vec { + let mut results = Vec::with_capacity(candidates.len()); + + for candidate in candidates { + match self + .extract( + context, + query, + &candidate.corpus, + &candidate.name, + &candidate.memory_id, + ) + .await + { + Ok(result) => results.push(result), + Err(e) => { + tracing::warn!( + memory_id = %candidate.memory_id, + error = %e, + "Extraction failed for candidate, using full corpus as fallback" + ); + results.push(ExtractionResult { + memory_id: candidate.memory_id.clone(), + segment: candidate.corpus.clone(), + confidence: 0.0, + }); + } + } + } + + results + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::ExtractionConfig; + + #[test] + fn test_extraction_config_defaults() { + let config = ExtractionConfig::default(); + assert!(config.enabled); + assert_eq!(config.max_tokens, 256); + assert!((config.temperature - 0.1).abs() < f32::EPSILON); + } + + #[test] + fn test_extraction_error_display() { + let err = ExtractionError::ParseError("bad format".to_string()); + assert_eq!( + err.to_string(), + "failed to parse extraction response: bad format" + ); + + let err2 = ExtractionError::InferenceFailed(tonic::Status::internal("test error")); + assert!(err2.to_string().contains("inference call failed")); + } + + #[test] + fn test_extraction_result_fields() { + let result = ExtractionResult { + memory_id: "mem-1".to_string(), + segment: "relevant text".to_string(), + confidence: 0.85, + }; + assert_eq!(result.memory_id, "mem-1"); + assert_eq!(result.segment, "relevant text"); + assert!((result.confidence - 0.85).abs() < f32::EPSILON); + } + + /// Mock Model Gateway server for extraction integration tests. 
+ mod mock_gateway { + use llm_multiverse_proto::llm_multiverse::v1::model_gateway_service_server::{ + ModelGatewayService, ModelGatewayServiceServer, + }; + use llm_multiverse_proto::llm_multiverse::v1::{ + GenerateEmbeddingRequest, GenerateEmbeddingResponse, InferenceRequest, + InferenceResponse, IsModelReadyRequest, IsModelReadyResponse, + StreamInferenceRequest, StreamInferenceResponse, + }; + use std::net::SocketAddr; + use std::sync::atomic::{AtomicBool, Ordering}; + use tonic::{Request, Response, Status}; + + pub struct MockGateway { + /// When true, the inference call returns an error. + pub should_fail: AtomicBool, + /// Canned response text for inference calls. + pub response_text: String, + } + + impl MockGateway { + pub fn new(response_text: &str) -> Self { + Self { + should_fail: AtomicBool::new(false), + response_text: response_text.to_string(), + } + } + + pub fn failing() -> Self { + let gw = Self::new(""); + gw.should_fail.store(true, Ordering::Relaxed); + gw + } + } + + #[tonic::async_trait] + impl ModelGatewayService for MockGateway { + type StreamInferenceStream = + tokio_stream::Empty<Result<StreamInferenceResponse, Status>>; + + async fn stream_inference( + &self, + _request: Request<StreamInferenceRequest>, + ) -> Result<Response<Self::StreamInferenceStream>, Status> { + Err(Status::unimplemented("not needed for extraction tests")) + } + + async fn inference( + &self, + _request: Request<InferenceRequest>, + ) -> Result<Response<InferenceResponse>, Status> { + if self.should_fail.load(Ordering::Relaxed) { + return Err(Status::internal("mock inference error")); + } + Ok(Response::new(InferenceResponse { + text: self.response_text.clone(), + finish_reason: "stop".to_string(), + tokens_used: 50, + })) + } + + async fn generate_embedding( + &self, + _request: Request<GenerateEmbeddingRequest>, + ) -> Result<Response<GenerateEmbeddingResponse>, Status> { + Err(Status::unimplemented("not needed for extraction tests")) + } + + async fn is_model_ready( + &self, + _request: Request<IsModelReadyRequest>, + ) -> Result<Response<IsModelReadyResponse>, Status> { + Err(Status::unimplemented("not needed for extraction tests")) + } + } + + /// Start a mock Model Gateway server and return its address.
+ pub async fn start_mock_gateway(gateway: MockGateway) -> SocketAddr { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener); + tonic::transport::Server::builder() + .add_service(ModelGatewayServiceServer::new(gateway)) + .serve_with_incoming(incoming) + .await + .unwrap(); + }); + + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + addr + } + } + + #[tokio::test] + async fn test_extract_single_success() { + let gateway = + mock_gateway::MockGateway::new("SEGMENT: The relevant part.\nCONFIDENCE: 0.85"); + let addr = mock_gateway::start_mock_gateway(gateway).await; + + let client = ExtractionClient::connect( + &format!("http://{addr}"), + ExtractionConfig::default(), + ) + .await + .unwrap(); + + let ctx = SessionContext { + session_id: "test-sess".to_string(), + ..Default::default() + }; + + let result = client + .extract(&ctx, "test query", "full corpus text here", "mem name", "mem-1") + .await + .unwrap(); + + assert_eq!(result.memory_id, "mem-1"); + assert_eq!(result.segment, "The relevant part."); + assert!((result.confidence - 0.85).abs() < f32::EPSILON); + } + + #[tokio::test] + async fn test_extract_single_gateway_error() { + let gateway = mock_gateway::MockGateway::failing(); + let addr = mock_gateway::start_mock_gateway(gateway).await; + + let client = ExtractionClient::connect( + &format!("http://{addr}"), + ExtractionConfig::default(), + ) + .await + .unwrap(); + + let ctx = SessionContext { + session_id: "test-sess".to_string(), + ..Default::default() + }; + + let result = client + .extract(&ctx, "test query", "corpus", "name", "mem-1") + .await; + + assert!(result.is_err()); + match result.unwrap_err() { + ExtractionError::InferenceFailed(status) => { + assert_eq!(status.code(), tonic::Code::Internal); + } + other => panic!("expected InferenceFailed, got: 
{other:?}"), + } + } + + #[tokio::test] + async fn test_extract_single_empty_corpus() { + let gateway = mock_gateway::MockGateway::new("should not be called"); + let addr = mock_gateway::start_mock_gateway(gateway).await; + + let client = ExtractionClient::connect( + &format!("http://{addr}"), + ExtractionConfig::default(), + ) + .await + .unwrap(); + + let ctx = SessionContext { + session_id: "test-sess".to_string(), + ..Default::default() + }; + + let result = client + .extract(&ctx, "test query", "", "mem name", "mem-1") + .await + .unwrap(); + + assert_eq!(result.memory_id, "mem-1"); + assert!(result.segment.is_empty()); + assert!((result.confidence - 0.0).abs() < f32::EPSILON); + } + + #[tokio::test] + async fn test_extract_batch_all_success() { + let gateway = + mock_gateway::MockGateway::new("SEGMENT: Extracted text.\nCONFIDENCE: 0.9"); + let addr = mock_gateway::start_mock_gateway(gateway).await; + + let client = ExtractionClient::connect( + &format!("http://{addr}"), + ExtractionConfig::default(), + ) + .await + .unwrap(); + + let ctx = SessionContext { + session_id: "test-sess".to_string(), + ..Default::default() + }; + + let candidates = vec![ + RetrievalCandidate { + memory_id: "mem-1".to_string(), + name: "First".to_string(), + description: "Desc 1".to_string(), + corpus: "Corpus 1".to_string(), + tags: vec![], + correlating_ids: vec![], + provenance: 1, + created_at: None, + last_accessed: None, + access_count: 0, + name_score: 0.9, + description_score: 0.8, + corpus_score: 0.7, + final_score: 0.8, + provenance_metadata: None, + }, + RetrievalCandidate { + memory_id: "mem-2".to_string(), + name: "Second".to_string(), + description: "Desc 2".to_string(), + corpus: "Corpus 2".to_string(), + tags: vec![], + correlating_ids: vec![], + provenance: 1, + created_at: None, + last_accessed: None, + access_count: 0, + name_score: 0.85, + description_score: 0.75, + corpus_score: 0.65, + final_score: 0.75, + provenance_metadata: None, + }, + ]; + + let results = 
client.extract_batch(&ctx, "test query", &candidates).await; + assert_eq!(results.len(), 2); + assert_eq!(results[0].memory_id, "mem-1"); + assert_eq!(results[0].segment, "Extracted text."); + assert_eq!(results[1].memory_id, "mem-2"); + assert_eq!(results[1].segment, "Extracted text."); + } + + #[tokio::test] + async fn test_extract_batch_partial_failure() { + // Use a gateway that fails -- all candidates will get fallback + let gateway = mock_gateway::MockGateway::failing(); + let addr = mock_gateway::start_mock_gateway(gateway).await; + + let client = ExtractionClient::connect( + &format!("http://{addr}"), + ExtractionConfig::default(), + ) + .await + .unwrap(); + + let ctx = SessionContext { + session_id: "test-sess".to_string(), + ..Default::default() + }; + + let candidates = vec![ + RetrievalCandidate { + memory_id: "mem-1".to_string(), + name: "First".to_string(), + description: "Desc 1".to_string(), + corpus: "Full corpus text 1".to_string(), + tags: vec![], + correlating_ids: vec![], + provenance: 1, + created_at: None, + last_accessed: None, + access_count: 0, + name_score: 0.9, + description_score: 0.8, + corpus_score: 0.7, + final_score: 0.8, + provenance_metadata: None, + }, + RetrievalCandidate { + memory_id: "mem-2".to_string(), + name: "Second".to_string(), + description: "Desc 2".to_string(), + corpus: "Full corpus text 2".to_string(), + tags: vec![], + correlating_ids: vec![], + provenance: 1, + created_at: None, + last_accessed: None, + access_count: 0, + name_score: 0.85, + description_score: 0.75, + corpus_score: 0.65, + final_score: 0.75, + provenance_metadata: None, + }, + ]; + + let results = client.extract_batch(&ctx, "test query", &candidates).await; + + // All should have fallback (full corpus, confidence 0.0) + assert_eq!(results.len(), 2); + assert_eq!(results[0].memory_id, "mem-1"); + assert_eq!(results[0].segment, "Full corpus text 1"); + assert!((results[0].confidence - 0.0).abs() < f32::EPSILON); + assert_eq!(results[1].memory_id, 
"mem-2"); + assert_eq!(results[1].segment, "Full corpus text 2"); + assert!((results[1].confidence - 0.0).abs() < f32::EPSILON); + } +} diff --git a/services/memory/src/extraction/prompt.rs b/services/memory/src/extraction/prompt.rs new file mode 100644 index 0000000..9288fb3 --- /dev/null +++ b/services/memory/src/extraction/prompt.rs @@ -0,0 +1,181 @@ +//! Extraction prompt template and response parsing. +//! +//! Provides the prompt template for instructing a lightweight model to extract +//! the relevant segment from a memory corpus, and a parser for the structured +//! response format. + +/// Result of extracting a relevant segment from a memory corpus. +#[derive(Debug, Clone)] +pub struct ExtractionResult { + /// The memory ID this extraction belongs to. + pub memory_id: String, + /// The extracted relevant segment. + pub segment: String, + /// Confidence score (0.0-1.0) parsed from the model's response. + pub confidence: f32, +} + +/// Build the extraction prompt for a single memory entry. +/// +/// The prompt instructs the model to extract only the relevant segment +/// from the corpus, given the query context. Output format is structured +/// to enable reliable parsing. +pub fn build_extraction_prompt(query: &str, corpus: &str, memory_name: &str) -> String { + format!( + "Given the following search query and memory content, extract ONLY the segment \ +of the content that is most relevant to the query. Return a concise, focused \ +excerpt — not the full content.\n\ +\n\ +Query: {query}\n\ +Memory name: {memory_name}\n\ +\n\ +Content:\n\ +---\n\ +{corpus}\n\ +---\n\ +\n\ +Respond in this exact format:\n\ +SEGMENT: \n\ +CONFIDENCE: <0.0-1.0 score indicating relevance>" + ) +} + +/// Parse the model's extraction response into segment and confidence. 
+/// +/// Expects the format: +/// SEGMENT: <segment text> +/// CONFIDENCE: <score> +/// +/// Falls back gracefully: +/// - If CONFIDENCE is missing or unparseable, defaults to 0.5 +/// - If SEGMENT is missing, returns the full response as the segment +/// - If response is empty, returns empty segment with confidence 0.0 +pub fn parse_extraction_response(response: &str, memory_id: &str) -> ExtractionResult { + if response.is_empty() { + return ExtractionResult { + memory_id: memory_id.to_string(), + segment: String::new(), + confidence: 0.0, + }; + } + + let mut segment: Option<String> = None; + let mut confidence: Option<f32> = None; + + // Find the SEGMENT: and CONFIDENCE: markers + if let Some(seg_start) = response.find("SEGMENT:") { + let seg_text_start = seg_start + "SEGMENT:".len(); + let remaining = &response[seg_text_start..]; + + // The segment extends until "CONFIDENCE:" or end of string + let seg_end = remaining.find("CONFIDENCE:"); + let raw_segment = match seg_end { + Some(end) => &remaining[..end], + None => remaining, + }; + segment = Some(raw_segment.trim().to_string()); + } + + if let Some(conf_start) = response.find("CONFIDENCE:") { + let conf_text_start = conf_start + "CONFIDENCE:".len(); + let remaining = &response[conf_text_start..]; + let conf_str = remaining.split_whitespace().next().unwrap_or(""); + if let Ok(val) = conf_str.parse::<f32>() { + // Clamp to [0.0, 1.0] + confidence = Some(val.clamp(0.0, 1.0)); + } + } + + ExtractionResult { + memory_id: memory_id.to_string(), + segment: segment.unwrap_or_else(|| response.trim().to_string()), + confidence: confidence.unwrap_or(0.5), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_build_extraction_prompt_contains_query() { + let prompt = build_extraction_prompt("how to sort arrays", "some corpus", "mem name"); + assert!(prompt.contains("how to sort arrays")); + } + + #[test] + fn test_build_extraction_prompt_contains_corpus() { + let prompt = build_extraction_prompt("query", "the corpus content here", "mem
name"); + assert!(prompt.contains("the corpus content here")); + } + + #[test] + fn test_build_extraction_prompt_contains_memory_name() { + let prompt = build_extraction_prompt("query", "corpus", "Rust Sort Algorithm"); + assert!(prompt.contains("Rust Sort Algorithm")); + } + + #[test] + fn test_parse_extraction_response_valid() { + let response = "SEGMENT: The quicksort algorithm divides the array.\nCONFIDENCE: 0.8"; + let result = parse_extraction_response(response, "mem-1"); + assert_eq!(result.memory_id, "mem-1"); + assert_eq!(result.segment, "The quicksort algorithm divides the array."); + assert!((result.confidence - 0.8).abs() < f32::EPSILON); + } + + #[test] + fn test_parse_extraction_response_missing_confidence() { + let response = "SEGMENT: Some relevant text here."; + let result = parse_extraction_response(response, "mem-2"); + assert_eq!(result.segment, "Some relevant text here."); + assert!((result.confidence - 0.5).abs() < f32::EPSILON); + } + + #[test] + fn test_parse_extraction_response_invalid_confidence() { + let response = "SEGMENT: Relevant text.\nCONFIDENCE: not-a-number"; + let result = parse_extraction_response(response, "mem-3"); + assert_eq!(result.segment, "Relevant text."); + assert!((result.confidence - 0.5).abs() < f32::EPSILON); + } + + #[test] + fn test_parse_extraction_response_missing_segment() { + let response = "This is just some text without any markers."; + let result = parse_extraction_response(response, "mem-4"); + assert_eq!(result.segment, "This is just some text without any markers."); + assert!((result.confidence - 0.5).abs() < f32::EPSILON); + } + + #[test] + fn test_parse_extraction_response_multiline_segment() { + let response = + "SEGMENT: Line one of the segment.\nLine two of the segment.\nCONFIDENCE: 0.9"; + let result = parse_extraction_response(response, "mem-5"); + assert_eq!( + result.segment, + "Line one of the segment.\nLine two of the segment." 
+ ); + assert!((result.confidence - 0.9).abs() < f32::EPSILON); + } + + #[test] + fn test_parse_extraction_response_empty() { + let result = parse_extraction_response("", "mem-6"); + assert_eq!(result.memory_id, "mem-6"); + assert!(result.segment.is_empty()); + assert!((result.confidence - 0.0).abs() < f32::EPSILON); + } + + #[test] + fn test_parse_extraction_response_confidence_clamped() { + let response = "SEGMENT: text\nCONFIDENCE: 1.5"; + let result = parse_extraction_response(response, "mem-7"); + assert!((result.confidence - 1.0).abs() < f32::EPSILON); + + let response2 = "SEGMENT: text\nCONFIDENCE: -0.3"; + let result2 = parse_extraction_response(response2, "mem-8"); + assert!((result2.confidence - 0.0).abs() < f32::EPSILON); + } +} diff --git a/services/memory/src/lib.rs b/services/memory/src/lib.rs index 58b00c8..6c33196 100644 --- a/services/memory/src/lib.rs +++ b/services/memory/src/lib.rs @@ -1,6 +1,7 @@ pub mod config; pub mod db; pub mod embedding; +pub mod extraction; pub mod provenance; pub mod retrieval; pub mod service; diff --git a/services/memory/src/main.rs b/services/memory/src/main.rs index fc8075c..0e7d4f0 100644 --- a/services/memory/src/main.rs +++ b/services/memory/src/main.rs @@ -4,6 +4,7 @@ use llm_multiverse_proto::llm_multiverse::v1::memory_service_server::MemoryServi use memory_service::config::Config; use memory_service::db::DuckDbManager; use memory_service::embedding::EmbeddingClient; +use memory_service::extraction::ExtractionClient; use memory_service::service::MemoryServiceImpl; use tonic::transport::Server; use tracing_subscriber::EnvFilter; @@ -31,7 +32,8 @@ async fn main() -> anyhow::Result<()> { let db = Arc::new(DuckDbManager::new(&config.db_path)?); let retrieval_config = config.retrieval.clone(); let provenance_config = config.provenance.clone(); - let mut memory_service = MemoryServiceImpl::new(db, retrieval_config, provenance_config); + let extraction_config = config.extraction.clone(); + let mut memory_service = 
MemoryServiceImpl::new(db, retrieval_config, provenance_config, extraction_config.clone()); // Connect to Model Gateway for embedding generation if configured. if let Some(ref endpoint) = config.embedding_endpoint { @@ -50,6 +52,25 @@ async fn main() -> anyhow::Result<()> { } } + // Connect to Model Gateway for extraction if configured and enabled. + if extraction_config.enabled { + if let Some(ref endpoint) = config.embedding_endpoint { + match ExtractionClient::connect(endpoint, extraction_config).await { + Ok(client) => { + tracing::info!(endpoint = %endpoint, "Connected to Model Gateway for extraction"); + memory_service = memory_service.with_extraction_client(client); + } + Err(e) => { + tracing::warn!( + endpoint = %endpoint, + error = %e, + "Model Gateway unavailable for extraction — extraction disabled" + ); + } + } + } + } + tracing::info!(%addr, "Memory Service listening"); Server::builder() diff --git a/services/memory/src/retrieval/pipeline.rs b/services/memory/src/retrieval/pipeline.rs index 2ba19cd..269f3a1 100644 --- a/services/memory/src/retrieval/pipeline.rs +++ b/services/memory/src/retrieval/pipeline.rs @@ -199,7 +199,7 @@ mod tests { for r in &results { assert!(!r.memory_id.is_empty()); assert!(!r.name.is_empty()); - assert!(r.final_score > 0.0 || r.final_score == 0.0); + assert!(r.final_score >= 0.0); } // Results should be sorted by final_score descending diff --git a/services/memory/src/retrieval/stage3.rs b/services/memory/src/retrieval/stage3.rs index 43491b7..6672009 100644 --- a/services/memory/src/retrieval/stage3.rs +++ b/services/memory/src/retrieval/stage3.rs @@ -249,6 +249,7 @@ mod tests { use crate::embedding::store::format_vector_literal; /// Helper to insert a full memory with all three embeddings, tags, and correlations. 
+ #[allow(clippy::too_many_arguments)] fn insert_full_memory( conn: &Connection, id: &str, diff --git a/services/memory/src/service.rs b/services/memory/src/service.rs index 8854d01..2152d93 100644 --- a/services/memory/src/service.rs +++ b/services/memory/src/service.rs @@ -1,8 +1,9 @@ use std::sync::Arc; -use crate::config::{ProvenanceConfig, RetrievalConfig}; +use crate::config::{ExtractionConfig, ProvenanceConfig, RetrievalConfig}; use crate::db::DuckDbManager; use crate::embedding::{EmbeddingClient, EmbeddingGenerator}; +use crate::extraction::ExtractionClient; use crate::provenance::sanitizer::ContentSanitizer; use crate::provenance::TrustLevel; use crate::retrieval::{self, RetrievalCandidate, RetrievalError, RetrievalParams}; @@ -26,7 +27,9 @@ use tonic::{Request, Response, Status}; pub struct MemoryServiceImpl { db: Arc<DuckDbManager>, embedding_client: Option<Arc<Mutex<EmbeddingClient>>>, + extraction_client: Option<Arc<Mutex<ExtractionClient>>>, retrieval_config: RetrievalConfig, + extraction_config: ExtractionConfig, /// Used by write_memory for provenance requirement checks (issue #34). #[allow(dead_code)] provenance_config: ProvenanceConfig, @@ -41,11 +44,14 @@ impl MemoryServiceImpl { db: Arc<DuckDbManager>, retrieval_config: RetrievalConfig, provenance_config: ProvenanceConfig, + extraction_config: ExtractionConfig, ) -> Self { Self { db, embedding_client: None, + extraction_client: None, retrieval_config, + extraction_config, provenance_config, sanitizer: ContentSanitizer::new(), } @@ -61,6 +67,16 @@ impl MemoryServiceImpl { self.embedding_client = Some(Arc::new(Mutex::new(client))); self } + + /// Attach an extraction client for post-retrieval segment extraction. + /// + /// When an extraction client is attached and extraction is enabled, the + /// `query_memory` handler will call the Model Gateway `Inference` RPC to + /// extract the relevant segment from each candidate's corpus.
+ pub fn with_extraction_client(mut self, client: ExtractionClient) -> Self { + self.extraction_client = Some(Arc::new(Mutex::new(client))); + self + } } /// Convert a `NaiveDateTime` to a protobuf `Timestamp`. @@ -180,17 +196,45 @@ impl MemoryService for MemoryServiceImpl { } })?; + // Run extraction if enabled + let should_extract = self.extraction_config.enabled + && !req.skip_extraction + && self.extraction_client.is_some(); + + let extraction_results = if should_extract { + let extraction_client = self.extraction_client.as_ref().expect("checked above"); + let client = extraction_client.lock().await; + let results = client + .extract_batch(&ctx, &req.query, &candidates) + .await; + drop(client); + Some(results) + } else { + None + }; + // Stream results via channel let (tx, rx) = tokio::sync::mpsc::channel(candidates.len().max(1)); tokio::spawn(async move { for (rank, candidate) in candidates.into_iter().enumerate() { + let (extracted_segment, extraction_confidence) = + if let Some(ref extractions) = extraction_results { + if let Some(result) = extractions.get(rank) { + (Some(result.segment.clone()), Some(result.confidence)) + } else { + (None, None) + } + } else { + (None, None) + }; + let response = QueryMemoryResponse { rank: (rank + 1) as u32, entry: Some(candidate_to_memory_entry(&candidate)), cosine_similarity: candidate.final_score, is_cached: false, - cached_extracted_segment: None, - extraction_confidence: None, + cached_extracted_segment: extracted_segment, + extraction_confidence, }; if tx.send(Ok(response)).await.is_err() { break; // Client disconnected @@ -321,7 +365,7 @@ impl MemoryService for MemoryServiceImpl { #[cfg(test)] mod tests { use super::*; - use crate::config::{ProvenanceConfig, RetrievalConfig}; + use crate::config::{ExtractionConfig, ProvenanceConfig, RetrievalConfig}; use llm_multiverse_proto::llm_multiverse::v1::{MemoryEntry, SessionContext}; fn valid_ctx() -> SessionContext { @@ -337,6 +381,7 @@ mod tests { Arc::new(db), 
RetrievalConfig::default(), ProvenanceConfig::default(), + ExtractionConfig::default(), ) } @@ -472,7 +517,6 @@ mod tests { #[tokio::test] async fn test_query_memory_returns_streamed_results() { use crate::db::schema::{ensure_hnsw_index, EMBEDDING_DIM}; - use crate::embedding::mock::MockEmbeddingGenerator; use crate::embedding::store::format_vector_literal; use tokio_stream::StreamExt; @@ -508,16 +552,14 @@ }) .expect("DB setup failed"); - // Create service with mock embedding client - let mock_gen = MockEmbeddingGenerator::new(); - // Wrap the mock in an EmbeddingClient-compatible way via the trait - // We need to use the service's embedding_client field which expects EmbeddingClient, - // but we can set it directly for testing. + // Create service with mock embedding client via a real mock gateway server. + // (MockEmbeddingGenerator can't be used directly since the field expects EmbeddingClient.) let db_arc = Arc::new(db); let mut svc = MemoryServiceImpl::new( db_arc, RetrievalConfig::default(), ProvenanceConfig::default(), + ExtractionConfig::default(), ); // Set embedding client to mock by storing it as Arc<Mutex<EmbeddingClient>> // Since the field type is Option<Arc<Mutex<EmbeddingClient>>>, we need a different approach.
@@ -593,7 +635,7 @@ mod tests { .expect("connect failed"); svc = svc.with_embedding_client(embedding_client); - // Call query_memory + // Call query_memory (no extraction client attached, so extraction is skipped) let response = svc .query_memory(Request::new(QueryMemoryRequest { context: Some(valid_ctx()), @@ -641,6 +683,12 @@ mod tests { assert!(svc.embedding_client.is_none()); } + #[test] + fn test_service_new_returns_none_extraction_client() { + let svc = test_service(); + assert!(svc.extraction_client.is_none()); + } + #[test] fn test_service_new_returns_none_embedding_client() { let db = DuckDbManager::in_memory().expect("failed to create test DB"); @@ -648,6 +696,7 @@ mod tests { Arc::new(db), RetrievalConfig::default(), ProvenanceConfig::default(), + ExtractionConfig::default(), ); assert!(svc.embedding_client.is_none()); } @@ -809,6 +858,7 @@ mod tests { db_arc.clone(), RetrievalConfig::default(), ProvenanceConfig::default(), + ExtractionConfig::default(), ); let result = svc @@ -832,4 +882,354 @@ mod tests { assert!(prov.is_revoked); assert_eq!(prov.trust_level, crate::provenance::TrustLevel::Revoked); } + + /// Helper to set up a mock gateway server that handles both embedding and inference. + /// Returns the gateway address. 
+ mod extraction_test_helpers { + use crate::db::schema::EMBEDDING_DIM; + use llm_multiverse_proto::llm_multiverse::v1::model_gateway_service_server::{ + ModelGatewayService, ModelGatewayServiceServer, + }; + use llm_multiverse_proto::llm_multiverse::v1::{ + GenerateEmbeddingRequest, GenerateEmbeddingResponse, InferenceRequest, + InferenceResponse, IsModelReadyRequest, IsModelReadyResponse, + StreamInferenceRequest, StreamInferenceResponse, + }; + use std::net::SocketAddr; + use tonic::{Request, Response, Status}; + + pub struct FullMockGateway; + + #[tonic::async_trait] + impl ModelGatewayService for FullMockGateway { + type StreamInferenceStream = + tokio_stream::Empty<Result<StreamInferenceResponse, Status>>; + + async fn stream_inference( + &self, + _request: Request<StreamInferenceRequest>, + ) -> Result<Response<Self::StreamInferenceStream>, Status> { + Err(Status::unimplemented("not needed")) + } + + async fn inference( + &self, + _request: Request<InferenceRequest>, + ) -> Result<Response<InferenceResponse>, Status> { + Ok(Response::new(InferenceResponse { + text: "SEGMENT: Extracted relevant text.\nCONFIDENCE: 0.85".to_string(), + finish_reason: "stop".to_string(), + tokens_used: 30, + })) + } + + async fn generate_embedding( + &self, + _request: Request<GenerateEmbeddingRequest>, + ) -> Result<Response<GenerateEmbeddingResponse>, Status> { + let embedding = vec![0.10_f32; EMBEDDING_DIM]; + Ok(Response::new(GenerateEmbeddingResponse { + embedding, + dimensions: EMBEDDING_DIM as u32, + })) + } + + async fn is_model_ready( + &self, + _request: Request<IsModelReadyRequest>, + ) -> Result<Response<IsModelReadyResponse>, Status> { + Err(Status::unimplemented("not needed")) + } + } + + pub async fn start_full_mock_gateway() -> SocketAddr { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .unwrap(); + let addr = listener.local_addr().unwrap(); + let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(ModelGatewayServiceServer::new(FullMockGateway)) + .serve_with_incoming(incoming) + .await + .unwrap(); + }); + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + addr + } + } + + ///
Helper to populate DB with test data for extraction tests. + fn populate_test_db(db: &DuckDbManager) { + use crate::db::schema::{ensure_hnsw_index, EMBEDDING_DIM}; + use crate::embedding::store::format_vector_literal; + + db.with_connection(|conn| { + for i in 1..=2 { + let id = format!("mem-{i}"); + let name = format!("Memory {i}"); + let desc = format!("Description {i}"); + let corpus = format!("Full corpus content for memory {i}"); + conn.execute( + "INSERT INTO memories (id, name, description, corpus, provenance) VALUES (?, ?, ?, ?, 1)", + duckdb::params![id, name, desc, corpus], + )?; + + let embedding_value = 0.10; + for emb_type in &["name", "description", "corpus"] { + let vector = vec![embedding_value; EMBEDDING_DIM]; + let lit = format_vector_literal(&vector); + conn.execute( + &format!( + "INSERT INTO embeddings (memory_id, embedding_type, vector) VALUES (?, '{emb_type}', {lit})" + ), + duckdb::params![id], + )?; + } + } + ensure_hnsw_index(conn)?; + Ok(()) + }) + .expect("DB setup failed"); + } + + #[tokio::test] + async fn test_query_memory_with_extraction() { + use tokio_stream::StreamExt; + + let db = DuckDbManager::in_memory().expect("failed to create test DB"); + populate_test_db(&db); + let db_arc = Arc::new(db); + + let addr = extraction_test_helpers::start_full_mock_gateway().await; + let endpoint = format!("http://{addr}"); + + let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint) + .await + .expect("connect failed"); + let extraction_client = crate::extraction::ExtractionClient::connect( + &endpoint, + ExtractionConfig::default(), + ) + .await + .expect("connect failed"); + + let svc = MemoryServiceImpl::new( + db_arc, + RetrievalConfig::default(), + ProvenanceConfig::default(), + ExtractionConfig::default(), + ) + .with_embedding_client(embedding_client) + .with_extraction_client(extraction_client); + + let response = svc + .query_memory(Request::new(QueryMemoryRequest { + context: Some(valid_ctx()), + query: "test 
query".into(), + memory_type: String::new(), + limit: 10, + ..Default::default() + })) + .await + .expect("query_memory should succeed"); + + let mut stream = response.into_inner(); + let mut results = Vec::new(); + while let Some(item) = stream.next().await { + results.push(item.expect("stream item should be Ok")); + } + + assert!(!results.is_empty()); + for r in &results { + assert!( + r.cached_extracted_segment.is_some(), + "extracted segment should be populated" + ); + assert_eq!( + r.cached_extracted_segment.as_deref(), + Some("Extracted relevant text.") + ); + assert!( + r.extraction_confidence.is_some(), + "extraction confidence should be set" + ); + let conf = r.extraction_confidence.unwrap(); + assert!((conf - 0.85).abs() < f32::EPSILON); + } + } + + #[tokio::test] + async fn test_query_memory_skip_extraction() { + use tokio_stream::StreamExt; + + let db = DuckDbManager::in_memory().expect("failed to create test DB"); + populate_test_db(&db); + let db_arc = Arc::new(db); + + let addr = extraction_test_helpers::start_full_mock_gateway().await; + let endpoint = format!("http://{addr}"); + + let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint) + .await + .expect("connect failed"); + let extraction_client = crate::extraction::ExtractionClient::connect( + &endpoint, + ExtractionConfig::default(), + ) + .await + .expect("connect failed"); + + let svc = MemoryServiceImpl::new( + db_arc, + RetrievalConfig::default(), + ProvenanceConfig::default(), + ExtractionConfig::default(), + ) + .with_embedding_client(embedding_client) + .with_extraction_client(extraction_client); + + let response = svc + .query_memory(Request::new(QueryMemoryRequest { + context: Some(valid_ctx()), + query: "test query".into(), + memory_type: String::new(), + limit: 10, + skip_extraction: true, + ..Default::default() + })) + .await + .expect("query_memory should succeed"); + + let mut stream = response.into_inner(); + let mut results = Vec::new(); + while let 
Some(item) = stream.next().await { + results.push(item.expect("stream item should be Ok")); + } + + assert!(!results.is_empty()); + for r in &results { + assert!( + r.cached_extracted_segment.is_none(), + "extraction should be skipped when skip_extraction=true" + ); + assert!( + r.extraction_confidence.is_none(), + "extraction confidence should not be set when skipped" + ); + } + } + + #[tokio::test] + async fn test_query_memory_extraction_disabled_config() { + use tokio_stream::StreamExt; + + let db = DuckDbManager::in_memory().expect("failed to create test DB"); + populate_test_db(&db); + let db_arc = Arc::new(db); + + let addr = extraction_test_helpers::start_full_mock_gateway().await; + let endpoint = format!("http://{addr}"); + + let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint) + .await + .expect("connect failed"); + let extraction_client = crate::extraction::ExtractionClient::connect( + &endpoint, + ExtractionConfig::default(), + ) + .await + .expect("connect failed"); + + // Disable extraction via config + let extraction_config = ExtractionConfig { + enabled: false, + ..Default::default() + }; + + let svc = MemoryServiceImpl::new( + db_arc, + RetrievalConfig::default(), + ProvenanceConfig::default(), + extraction_config, + ) + .with_embedding_client(embedding_client) + .with_extraction_client(extraction_client); + + let response = svc + .query_memory(Request::new(QueryMemoryRequest { + context: Some(valid_ctx()), + query: "test query".into(), + memory_type: String::new(), + limit: 10, + ..Default::default() + })) + .await + .expect("query_memory should succeed"); + + let mut stream = response.into_inner(); + let mut results = Vec::new(); + while let Some(item) = stream.next().await { + results.push(item.expect("stream item should be Ok")); + } + + assert!(!results.is_empty()); + for r in &results { + assert!( + r.cached_extracted_segment.is_none(), + "extraction should be skipped when config disabled" + ); + } + } + + 
#[tokio::test] + async fn test_query_memory_no_extraction_client() { + use tokio_stream::StreamExt; + + let db = DuckDbManager::in_memory().expect("failed to create test DB"); + populate_test_db(&db); + let db_arc = Arc::new(db); + + let addr = extraction_test_helpers::start_full_mock_gateway().await; + let endpoint = format!("http://{addr}"); + + let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint) + .await + .expect("connect failed"); + + // No extraction client attached + let svc = MemoryServiceImpl::new( + db_arc, + RetrievalConfig::default(), + ProvenanceConfig::default(), + ExtractionConfig::default(), + ) + .with_embedding_client(embedding_client); + + let response = svc + .query_memory(Request::new(QueryMemoryRequest { + context: Some(valid_ctx()), + query: "test query".into(), + memory_type: String::new(), + limit: 10, + ..Default::default() + })) + .await + .expect("query_memory should succeed"); + + let mut stream = response.into_inner(); + let mut results = Vec::new(); + while let Some(item) = stream.next().await { + results.push(item.expect("stream item should be Ok")); + } + + assert!(!results.is_empty()); + for r in &results { + assert!( + r.cached_extracted_segment.is_none(), + "extraction should not happen without extraction client" + ); + } + } }