feat: staged retrieval pipeline (#30): issue-030.md
This commit is contained in:
548
implementation-plans/issue-030.md
Normal file
548
implementation-plans/issue-030.md
Normal file
@@ -0,0 +1,548 @@
|
||||
# Implementation Plan — Issue #30: Implement staged retrieval (coarse-to-fine, 4 stages)
|
||||
|
||||
## Metadata
|
||||
|
||||
| Field | Value |
|
||||
|---|---|
|
||||
| Issue | [#30](https://git.shahondin1624.de/llm-multiverse/llm-multiverse/issues/30) |
|
||||
| Title | Implement staged retrieval (coarse-to-fine, 4 stages) |
|
||||
| Milestone | Phase 4: Memory Service |
|
||||
| Labels | |
|
||||
| Status | `COMPLETED` |
|
||||
| Language | Rust |
|
||||
| Related Plans | [issue-011.md](issue-011.md), [issue-027.md](issue-027.md), [issue-028.md](issue-028.md), [issue-029.md](issue-029.md) |
|
||||
| Blocked by | #29 (completed) |
|
||||
|
||||
## Acceptance Criteria
|
||||
|
||||
- [ ] Stage 1: HNSW vector search returns top-K candidates
|
||||
- [ ] Stage 2: Metadata/keyword filter narrows results
|
||||
- [ ] Stage 3: Re-ranking scores and reorders results
|
||||
- [ ] Stage 4: Relevance threshold applied, low-confidence results dropped
|
||||
- [ ] Each stage is configurable (K, thresholds, filters)
|
||||
- [ ] Performance: full pipeline under 100ms for typical queries
|
||||
|
||||
## Architecture Analysis
|
||||
|
||||
### Service Context
|
||||
|
||||
This issue belongs to the **Memory Service** (Rust). It implements the core staged retrieval pipeline that powers the `QueryMemory` server-streaming RPC defined in `memory.proto`. The architecture document specifies a non-negotiable staged retrieval pattern (coarse-to-fine) with 4 stages:
|
||||
|
||||
1. Embed query, cosine similarity on `name_embeddings`, top 20
|
||||
2. Cosine similarity on `description_embeddings` of top 20, top 5
|
||||
3. Full corpus load of top 5, optional corpus embedding re-rank
|
||||
4. Correlation expansion: agent may request descriptions of `correlating_ids`
|
||||
|
||||
The proto defines:
|
||||
- `QueryMemoryRequest` with `context`, `query`, `memory_type` (tag filter), `limit`
|
||||
- `QueryMemoryResponse` (streamed) with `rank`, `entry`, `cosine_similarity`, `is_cached`, `cached_extracted_segment`
|
||||
- `MemoryEntry` with all fields including embeddings, tags, correlating_ids, provenance
|
||||
|
||||
### Existing Patterns
|
||||
|
||||
- **DuckDB access:** `DuckDbManager` wraps connection in `Mutex<Connection>` with `with_connection()` closure pattern (see `services/memory/src/db/mod.rs:84-90`).
|
||||
- **Embedding generation:** `EmbeddingClient` implements `EmbeddingGenerator` trait for mock-testable embedding (see `services/memory/src/embedding/mod.rs:100-109`). The `generate()` method returns `Vec<f32>` of `EMBEDDING_DIM` (768) dimensions.
|
||||
- **Embedding storage:** `store_embedding()` and `store_embeddings()` in `services/memory/src/embedding/store.rs` use `format_vector_literal()` to produce DuckDB `FLOAT[768]` array literals.
|
||||
- **Schema:** `embeddings` table has `(memory_id, embedding_type, vector FLOAT[768])` with HNSW index `idx_embeddings_hnsw`. The `embedding_type` column distinguishes `'name'`, `'description'`, `'corpus'` vectors. Filtering by type is done at query time (see `services/memory/src/db/schema.rs:59-60`).
|
||||
- **Vector similarity:** Existing test in `services/memory/src/db/schema.rs:407-485` demonstrates `array_cosine_similarity()` usage with `ORDER BY score DESC`.
|
||||
- **Service struct:** `MemoryServiceImpl` holds `Arc<DuckDbManager>` and `Option<Arc<Mutex<EmbeddingClient>>>` (see `services/memory/src/service.rs:17-22`).
|
||||
- **gRPC streaming:** `QueryMemoryStream` is typed as `tokio_stream::wrappers::ReceiverStream<Result<QueryMemoryResponse, Status>>` (see `services/memory/src/service.rs:46-47`).
|
||||
|
||||
### Dependencies
|
||||
|
||||
- **EmbeddingGenerator trait** (from issue #29) -- needed to embed the incoming query text before Stage 1.
|
||||
- **DuckDB VSS extension** -- provides `array_cosine_similarity()` function and HNSW indexing for efficient vector search.
|
||||
- **No new external crate dependencies** -- all required functionality is available via existing `duckdb`, `tonic`, `tokio-stream` dependencies.
|
||||
- **Proto stubs** -- `QueryMemoryRequest`, `QueryMemoryResponse`, `MemoryEntry`, `MemoryProvenance`, `SessionContext` are all generated.
|
||||
|
||||
## Implementation Steps
|
||||
|
||||
### 1. Types & Configuration
|
||||
|
||||
**Add retrieval configuration to `services/memory/src/config.rs`:**
|
||||
|
||||
```rust
|
||||
/// Configuration for the staged retrieval pipeline.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct RetrievalConfig {
|
||||
/// Stage 1: Number of candidates from HNSW name-embedding search (default: 20).
|
||||
#[serde(default = "default_stage1_top_k")]
|
||||
pub stage1_top_k: u32,
|
||||
|
||||
/// Stage 2: Number of candidates after description-embedding re-rank (default: 5).
|
||||
#[serde(default = "default_stage2_top_k")]
|
||||
pub stage2_top_k: u32,
|
||||
|
||||
/// Stage 4: Minimum cosine similarity score to include in final results (default: 0.3).
|
||||
#[serde(default = "default_relevance_threshold")]
|
||||
pub relevance_threshold: f32,
|
||||
}
|
||||
```
|
||||
|
||||
Add `retrieval: RetrievalConfig` field to the main `Config` struct with `#[serde(default)]`.
|
||||
|
||||
**Define retrieval pipeline types in a new `services/memory/src/retrieval/mod.rs`:**
|
||||
|
||||
```rust
|
||||
/// A candidate memory entry passing through the retrieval pipeline.
|
||||
/// Carries accumulated scores from each stage.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievalCandidate {
|
||||
pub memory_id: String,
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub corpus: String,
|
||||
pub tags: Vec<String>,
|
||||
pub correlating_ids: Vec<String>,
|
||||
pub provenance: i32,
|
||||
pub created_at: Option<chrono::NaiveDateTime>,
|
||||
pub last_accessed: Option<chrono::NaiveDateTime>,
|
||||
pub access_count: u32,
|
||||
/// Cosine similarity score from Stage 1 (name embedding).
|
||||
pub name_score: f32,
|
||||
/// Cosine similarity score from Stage 2 (description embedding).
|
||||
pub description_score: f32,
|
||||
/// Cosine similarity score from Stage 3 (corpus embedding re-rank).
|
||||
pub corpus_score: f32,
|
||||
/// Combined/final score after all stages.
|
||||
pub final_score: f32,
|
||||
}
|
||||
|
||||
/// Parameters controlling the staged retrieval pipeline.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievalParams {
|
||||
/// Maximum candidates from Stage 1 (name HNSW search).
|
||||
pub stage1_top_k: u32,
|
||||
/// Maximum candidates from Stage 2 (description re-rank).
|
||||
pub stage2_top_k: u32,
|
||||
/// Minimum score threshold for Stage 4 cutoff.
|
||||
pub relevance_threshold: f32,
|
||||
/// Tag filter (from QueryMemoryRequest.memory_type).
|
||||
pub tag_filter: Option<String>,
|
||||
/// Final result limit (from QueryMemoryRequest.limit, defaults to 5).
|
||||
pub result_limit: u32,
|
||||
}
|
||||
|
||||
/// Errors specific to the retrieval pipeline.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RetrievalError {
|
||||
#[error("database error: {0}")]
|
||||
Database(#[from] crate::db::DbError),
|
||||
|
||||
#[error("embedding generation failed: {0}")]
|
||||
Embedding(#[from] crate::embedding::EmbeddingError),
|
||||
|
||||
#[error("no embedding client configured")]
|
||||
NoEmbeddingClient,
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Core Logic
|
||||
|
||||
**Create `services/memory/src/retrieval/pipeline.rs` -- The 4-stage retrieval pipeline:**
|
||||
|
||||
```rust
|
||||
/// Execute the full staged retrieval pipeline.
|
||||
///
|
||||
/// Stage 1: Embed query -> cosine similarity on name_embeddings -> top stage1_top_k
|
||||
/// Stage 2: Cosine similarity on description_embeddings of Stage 1 results -> top stage2_top_k
|
||||
/// Stage 3: Load full corpus of Stage 2 results, re-rank by corpus_embedding similarity
|
||||
/// Stage 4: Apply relevance threshold, drop low-confidence results
|
||||
pub async fn execute_pipeline(
|
||||
db: &DuckDbManager,
|
||||
embedding_client: &dyn EmbeddingGenerator,
|
||||
context: &SessionContext,
|
||||
query: &str,
|
||||
params: &RetrievalParams,
|
||||
) -> Result<Vec<RetrievalCandidate>, RetrievalError>;
|
||||
```
|
||||
|
||||
**Stage 1 -- HNSW Vector Search (name embeddings):**
|
||||
|
||||
Create `services/memory/src/retrieval/stage1.rs`:
|
||||
|
||||
```rust
|
||||
/// Execute Stage 1: HNSW vector search on name embeddings.
|
||||
///
|
||||
/// Embeds the query text via the EmbeddingGenerator, then runs a cosine similarity
|
||||
/// query against the `embeddings` table filtered to `embedding_type = 'name'`.
|
||||
/// Returns up to `top_k` candidates ordered by descending cosine similarity.
|
||||
///
|
||||
/// SQL pattern:
|
||||
/// SELECT e.memory_id, array_cosine_similarity(e.vector, <query_vector>::FLOAT[768]) AS score
|
||||
/// FROM embeddings e
|
||||
/// WHERE e.embedding_type = 'name'
|
||||
/// ORDER BY score DESC
|
||||
/// LIMIT ?
|
||||
pub fn search_by_name_embedding(
|
||||
conn: &Connection,
|
||||
query_vector: &[f32],
|
||||
top_k: u32,
|
||||
) -> Result<Vec<(String, f32)>, DbError>;
|
||||
```
|
||||
|
||||
**Stage 2 -- Description Embedding Re-rank with Optional Tag Filter:**
|
||||
|
||||
Create `services/memory/src/retrieval/stage2.rs`:
|
||||
|
||||
```rust
|
||||
/// Execute Stage 2: Re-rank Stage 1 candidates by description embedding similarity.
|
||||
///
|
||||
/// For each candidate from Stage 1, compute cosine similarity between the query vector
|
||||
/// and the candidate's description embedding. Optionally filter by tag (memory_type).
|
||||
/// Returns the top `top_k` candidates ordered by description score.
|
||||
///
|
||||
/// SQL pattern (per candidate batch):
|
||||
/// SELECT e.memory_id, array_cosine_similarity(e.vector, <query_vector>::FLOAT[768]) AS score
|
||||
/// FROM embeddings e
|
||||
/// WHERE e.embedding_type = 'description'
|
||||
/// AND e.memory_id IN (<stage1_ids>)
|
||||
/// ORDER BY score DESC
|
||||
/// LIMIT ?
|
||||
///
|
||||
/// Tag filter (when memory_type is set):
|
||||
/// AND e.memory_id IN (SELECT memory_id FROM memory_tags WHERE tag = ?)
|
||||
pub fn rerank_by_description(
|
||||
conn: &Connection,
|
||||
query_vector: &[f32],
|
||||
candidate_ids: &[String],
|
||||
tag_filter: Option<&str>,
|
||||
top_k: u32,
|
||||
) -> Result<Vec<(String, f32)>, DbError>;
|
||||
```
|
||||
|
||||
**Stage 3 -- Corpus Load and Re-rank:**
|
||||
|
||||
Create `services/memory/src/retrieval/stage3.rs`:
|
||||
|
||||
```rust
|
||||
/// Execute Stage 3: Load full memory entries and re-rank by corpus embedding.
|
||||
///
|
||||
/// Loads full memory rows (memories + tags + correlations) for the Stage 2 candidates.
|
||||
/// Computes cosine similarity between the query vector and each candidate's corpus
|
||||
/// embedding. Combines name, description, and corpus scores into a final score using
|
||||
/// a weighted average:
|
||||
/// final_score = 0.3 * name_score + 0.3 * description_score + 0.4 * corpus_score
|
||||
///
|
||||
/// Returns candidates as fully populated `RetrievalCandidate` structs sorted by final_score DESC.
|
||||
pub fn load_and_rerank(
|
||||
conn: &Connection,
|
||||
query_vector: &[f32],
|
||||
candidate_ids: &[String],
|
||||
name_scores: &HashMap<String, f32>,
|
||||
description_scores: &HashMap<String, f32>,
|
||||
) -> Result<Vec<RetrievalCandidate>, DbError>;
|
||||
```
|
||||
|
||||
**Stage 4 -- Relevance Threshold Cutoff:**
|
||||
|
||||
Create `services/memory/src/retrieval/stage4.rs`:
|
||||
|
||||
```rust
|
||||
/// Execute Stage 4: Apply relevance threshold and limit results.
|
||||
///
|
||||
/// Filters candidates whose `final_score` is below `threshold`.
|
||||
/// Truncates to `limit` results. Updates `last_accessed` and increments
|
||||
/// `access_count` for returned entries.
|
||||
pub fn apply_threshold(
|
||||
candidates: Vec<RetrievalCandidate>,
|
||||
threshold: f32,
|
||||
limit: u32,
|
||||
) -> Vec<RetrievalCandidate>;
|
||||
|
||||
/// Update access tracking for retrieved memories.
|
||||
///
|
||||
/// Increments `access_count` and sets `last_accessed` to current timestamp
|
||||
/// for all memory IDs in the result set.
|
||||
pub fn update_access_tracking(
|
||||
conn: &Connection,
|
||||
memory_ids: &[String],
|
||||
) -> Result<(), DbError>;
|
||||
```
|
||||
|
||||
**Pipeline orchestration in `services/memory/src/retrieval/pipeline.rs`:**
|
||||
|
||||
The `execute_pipeline` function coordinates the stages:
|
||||
|
||||
1. Generate query embedding via `EmbeddingGenerator::generate()`.
|
||||
2. Call `stage1::search_by_name_embedding()` with the query vector and `params.stage1_top_k`.
|
||||
3. If no results, return empty.
|
||||
4. Call `stage2::rerank_by_description()` with Stage 1 candidate IDs, tag filter, and `params.stage2_top_k`.
|
||||
5. If no results after filtering, return empty.
|
||||
6. Call `stage3::load_and_rerank()` with Stage 2 candidate IDs and accumulated scores.
|
||||
7. Call `stage4::apply_threshold()` with `params.relevance_threshold` and `params.result_limit`.
|
||||
8. Call `stage4::update_access_tracking()` for returned memory IDs.
|
||||
9. Return the final sorted candidates.
|
||||
|
||||
### 3. gRPC Handler Wiring
|
||||
|
||||
**Update `services/memory/src/service.rs` -- Implement `query_memory`:**
|
||||
|
||||
Replace the `Unimplemented` stub with the actual pipeline:
|
||||
|
||||
```rust
|
||||
async fn query_memory(
|
||||
&self,
|
||||
request: Request<QueryMemoryRequest>,
|
||||
) -> Result<Response<Self::QueryMemoryStream>, Status> {
|
||||
// ... existing validation ...
|
||||
|
||||
let embedding_client = self.embedding_client.as_ref()
|
||||
.ok_or_else(|| Status::failed_precondition("embedding client not configured"))?;
|
||||
|
||||
let params = RetrievalParams {
|
||||
stage1_top_k: self.retrieval_config.stage1_top_k,
|
||||
stage2_top_k: self.retrieval_config.stage2_top_k,
|
||||
relevance_threshold: self.retrieval_config.relevance_threshold,
|
||||
tag_filter: if req.memory_type.is_empty() { None } else { Some(req.memory_type.clone()) },
|
||||
result_limit: if req.limit == 0 { 5 } else { req.limit },
|
||||
};
|
||||
|
||||
// Run the pipeline (embedding generation is async, DB queries are sync in with_connection)
|
||||
let mut client = embedding_client.lock().await;
|
||||
let candidates = retrieval::pipeline::execute_pipeline(
|
||||
&self.db, &*client, &ctx, &req.query, ¶ms,
|
||||
).await.map_err(|e| match e {
|
||||
RetrievalError::NoEmbeddingClient => Status::failed_precondition("embedding client not configured"),
|
||||
RetrievalError::Embedding(e) => Status::unavailable(format!("embedding error: {e}")),
|
||||
RetrievalError::Database(e) => Status::internal(format!("database error: {e}")),
|
||||
})?;
|
||||
|
||||
// Stream results via channel
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(candidates.len().max(1));
|
||||
tokio::spawn(async move {
|
||||
for (rank, candidate) in candidates.into_iter().enumerate() {
|
||||
let response = QueryMemoryResponse {
|
||||
rank: (rank + 1) as u32,
|
||||
entry: Some(candidate_to_memory_entry(&candidate)),
|
||||
cosine_similarity: candidate.final_score,
|
||||
is_cached: false,
|
||||
cached_extracted_segment: None,
|
||||
};
|
||||
if tx.send(Ok(response)).await.is_err() {
|
||||
break; // Client disconnected
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Response::new(ReceiverStream::new(rx)))
|
||||
}
|
||||
```
|
||||
|
||||
**Add `RetrievalConfig` to `MemoryServiceImpl`:**
|
||||
|
||||
```rust
|
||||
pub struct MemoryServiceImpl {
|
||||
db: Arc<DuckDbManager>,
|
||||
embedding_client: Option<Arc<Mutex<EmbeddingClient>>>,
|
||||
retrieval_config: RetrievalConfig,
|
||||
}
|
||||
|
||||
impl MemoryServiceImpl {
|
||||
pub fn new(db: Arc<DuckDbManager>, retrieval_config: RetrievalConfig) -> Self {
|
||||
Self {
|
||||
db,
|
||||
embedding_client: None,
|
||||
retrieval_config,
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Helper function to convert `RetrievalCandidate` to protobuf `MemoryEntry`:**
|
||||
|
||||
```rust
|
||||
fn candidate_to_memory_entry(candidate: &RetrievalCandidate) -> MemoryEntry {
|
||||
MemoryEntry {
|
||||
id: candidate.memory_id.clone(),
|
||||
name: candidate.name.clone(),
|
||||
description: candidate.description.clone(),
|
||||
tags: candidate.tags.clone(),
|
||||
correlating_ids: candidate.correlating_ids.clone(),
|
||||
corpus: candidate.corpus.clone(),
|
||||
name_embedding: vec![], // Embeddings not sent over the wire in query responses
|
||||
description_embedding: vec![],
|
||||
corpus_embedding: vec![],
|
||||
created_at: candidate.created_at.map(timestamp_to_proto),
|
||||
last_accessed: candidate.last_accessed.map(timestamp_to_proto),
|
||||
access_count: candidate.access_count,
|
||||
provenance: candidate.provenance,
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Service Integration
|
||||
|
||||
**Update `services/memory/src/main.rs` -- Pass retrieval config:**
|
||||
|
||||
```rust
|
||||
let retrieval_config = config.retrieval.clone();
|
||||
let mut memory_service = MemoryServiceImpl::new(db, retrieval_config);
|
||||
```
|
||||
|
||||
**No new cross-service dependencies.** The embedding client (Model Gateway) is already wired from issue #29. The retrieval pipeline uses it to embed the query text.
|
||||
|
||||
**Error mapping:** `RetrievalError` variants map to gRPC status codes:
|
||||
- `RetrievalError::NoEmbeddingClient` -> `Status::failed_precondition`
|
||||
- `RetrievalError::Embedding(_)` -> `Status::unavailable`
|
||||
- `RetrievalError::Database(_)` -> `Status::internal`
|
||||
|
||||
### 5. Tests
|
||||
|
||||
**Unit tests for each stage module:**
|
||||
|
||||
`services/memory/src/retrieval/stage1.rs`:
|
||||
|
||||
| Test Case | Description |
|
||||
|---|---|
|
||||
| `test_stage1_returns_top_k` | Insert 30 memories with name embeddings, query returns exactly `top_k` (20) |
|
||||
| `test_stage1_ordering` | Results are ordered by descending cosine similarity |
|
||||
| `test_stage1_empty_table` | Returns empty vec when no embeddings exist |
|
||||
| `test_stage1_fewer_than_k` | When fewer than K entries exist, returns all available |
|
||||
|
||||
`services/memory/src/retrieval/stage2.rs`:
|
||||
|
||||
| Test Case | Description |
|
||||
|---|---|
|
||||
| `test_stage2_narrows_candidates` | From 20 Stage 1 candidates, returns top 5 by description score |
|
||||
| `test_stage2_tag_filter` | Only candidates matching tag filter survive |
|
||||
| `test_stage2_no_tag_filter` | Without tag filter, all candidates are considered |
|
||||
| `test_stage2_empty_after_filter` | Returns empty when tag filter matches no candidates |
|
||||
|
||||
`services/memory/src/retrieval/stage3.rs`:
|
||||
|
||||
| Test Case | Description |
|
||||
|---|---|
|
||||
| `test_stage3_loads_full_entries` | Returned candidates have populated corpus, tags, correlating_ids |
|
||||
| `test_stage3_weighted_scoring` | Final score correctly combines name (0.3), description (0.3), corpus (0.4) scores |
|
||||
| `test_stage3_sorted_by_final_score` | Results sorted by final_score descending |
|
||||
|
||||
`services/memory/src/retrieval/stage4.rs`:
|
||||
|
||||
| Test Case | Description |
|
||||
|---|---|
|
||||
| `test_stage4_threshold_removes_low_scores` | Candidates below threshold are dropped |
|
||||
| `test_stage4_limit_truncates` | When more candidates than limit pass threshold, truncates to limit |
|
||||
| `test_stage4_all_above_threshold` | All candidates pass when all scores exceed threshold |
|
||||
| `test_stage4_all_below_threshold` | Returns empty when all candidates are below threshold |
|
||||
|
||||
`services/memory/src/retrieval/pipeline.rs`:
|
||||
|
||||
| Test Case | Description |
|
||||
|---|---|
|
||||
| `test_full_pipeline_end_to_end` | Insert test data, run full pipeline with `MockEmbeddingGenerator`, verify 4-stage flow |
|
||||
| `test_pipeline_no_results` | Query against empty database returns empty vec |
|
||||
| `test_pipeline_with_tag_filter` | Only tagged memories survive Stage 2 |
|
||||
| `test_pipeline_updates_access_tracking` | After pipeline, `last_accessed` and `access_count` are updated |
|
||||
|
||||
`services/memory/src/retrieval/mod.rs`:
|
||||
|
||||
| Test Case | Description |
|
||||
|---|---|
|
||||
| `test_retrieval_params_from_config` | `RetrievalParams` correctly constructed from `RetrievalConfig` |
|
||||
|
||||
**Integration tests in `services/memory/src/service.rs`:**
|
||||
|
||||
| Test Case | Description |
|
||||
|---|---|
|
||||
| `test_query_memory_returns_streamed_results` | Full gRPC handler test with mock embedding client and pre-populated DB |
|
||||
| `test_query_memory_no_embedding_client` | Returns `failed_precondition` when no embedding client configured |
|
||||
| `test_query_memory_respects_limit` | Query with `limit=2` returns at most 2 results |
|
||||
| `test_query_memory_tag_filter` | Query with `memory_type` set filters by tag |
|
||||
|
||||
**Mocking strategy:**
|
||||
- Use `MockEmbeddingGenerator` (from issue #29, `services/memory/src/embedding/mod.rs:236-284`) for all retrieval tests. It returns deterministic vectors based on text length, which is sufficient for testing pipeline ordering and threshold logic.
|
||||
- Use `DuckDbManager::in_memory()` for all DB operations.
|
||||
- Pre-populate the in-memory DB with test memories, embeddings, tags, and correlations using direct SQL inserts.
|
||||
|
||||
**Performance test (optional, not blocking):**
|
||||
|
||||
| Test Case | Description |
|
||||
|---|---|
|
||||
| `test_pipeline_performance` | Insert 1000 memories with embeddings, verify pipeline completes in under 100ms |
|
||||
|
||||
### Cargo Dependencies
|
||||
|
||||
No new crate dependencies required. All functionality is available via:
|
||||
- `duckdb` (vector similarity queries via `array_cosine_similarity`)
|
||||
- `tokio` / `tokio-stream` (async pipeline, streaming response)
|
||||
- `chrono` (timestamp handling for access tracking)
|
||||
- `thiserror` (error types)
|
||||
|
||||
### Trait Implementations
|
||||
|
||||
- `From<DbError> for RetrievalError` -- convert DB errors to retrieval errors
|
||||
- `From<EmbeddingError> for RetrievalError` -- convert embedding errors to retrieval errors
|
||||
|
||||
### Error Types
|
||||
|
||||
- `RetrievalError` -- enum covering database, embedding, and configuration errors (see Types section above)
|
||||
|
||||
## Files to Create/Modify
|
||||
|
||||
| File | Action | Purpose |
|
||||
|---|---|---|
|
||||
| `services/memory/src/config.rs` | Modify | Add `RetrievalConfig` struct with stage1_top_k, stage2_top_k, relevance_threshold; add `retrieval` field to `Config` |
|
||||
| `services/memory/src/lib.rs` | Modify | Add `pub mod retrieval;` |
|
||||
| `services/memory/src/retrieval/mod.rs` | Create | `RetrievalCandidate`, `RetrievalParams`, `RetrievalError`, module declarations |
|
||||
| `services/memory/src/retrieval/pipeline.rs` | Create | `execute_pipeline()` -- orchestrates all 4 stages |
|
||||
| `services/memory/src/retrieval/stage1.rs` | Create | `search_by_name_embedding()` -- HNSW vector search on name embeddings |
|
||||
| `services/memory/src/retrieval/stage2.rs` | Create | `rerank_by_description()` -- description embedding re-rank with tag filter |
|
||||
| `services/memory/src/retrieval/stage3.rs` | Create | `load_and_rerank()` -- full entry load, corpus embedding re-rank, weighted scoring |
|
||||
| `services/memory/src/retrieval/stage4.rs` | Create | `apply_threshold()`, `update_access_tracking()` -- threshold cutoff and access bookkeeping |
|
||||
| `services/memory/src/service.rs` | Modify | Implement `query_memory` with retrieval pipeline; add `retrieval_config` to `MemoryServiceImpl`; add `candidate_to_memory_entry()` helper |
|
||||
| `services/memory/src/main.rs` | Modify | Pass `RetrievalConfig` to `MemoryServiceImpl::new()` |
|
||||
|
||||
## Risks and Edge Cases
|
||||
|
||||
- **HNSW index on empty table:** DuckDB VSS may fail to use the HNSW index if it was deferred during schema creation. Stage 1 should call `ensure_hnsw_index()` before querying, or fall back to a sequential scan if the index does not exist. The existing `ensure_hnsw_index()` in `services/memory/src/db/schema.rs:143-146` handles this.
|
||||
- **Missing embeddings:** A memory entry may exist in the `memories` table without corresponding embeddings (e.g., if embedding generation failed). Stage 1 will simply not return such entries (no join hit in `embeddings`). This is the correct behavior -- entries without embeddings are not retrievable via vector search.
|
||||
- **Query vector dimension mismatch:** If the embedding client returns a vector with the wrong dimension, the `EmbeddingGenerator::generate()` method already validates dimensions (see `services/memory/src/embedding/mod.rs:224-229`). This error propagates as `RetrievalError::Embedding`.
|
||||
- **DuckDB `array_cosine_similarity` with zero vectors:** Empty text fields produce zero vectors (from issue #29). Cosine similarity with a zero vector is undefined (division by zero). DuckDB may return `NaN` or `NULL`. Stage 2/3 should handle this by treating `NaN`/`NULL` scores as 0.0.
|
||||
- **Large candidate sets in Stage 2:** The `IN (...)` clause with 20 IDs is well within SQL limits. For future scaling beyond 100s of candidates, consider using a temporary table instead.
|
||||
- **Concurrent access:** The `DuckDbManager` uses `Mutex<Connection>` which serializes all DB access. The retrieval pipeline holds the lock for the duration of each stage's DB query (not across stages). Between stages, the lock is released, allowing other operations (e.g., writes) to proceed. The async embedding call in `execute_pipeline` happens outside the lock.
|
||||
- **Performance target (100ms):** The target is achievable for typical workloads (hundreds to low thousands of memories). HNSW indexing provides sublinear search in Stage 1. Stages 2-4 operate on small candidate sets (20 -> 5 -> final). The main latency risk is the embedding generation call (Stage 0), which depends on Model Gateway / Ollama inference speed for nomic-embed-text. This is typically 5-20ms for short queries.
|
||||
- **Weighted score formula:** The weights (0.3/0.3/0.4) are initial values. Consider making these configurable in `RetrievalConfig` for tuning. Start with fixed weights for simplicity.
|
||||
|
||||
## Deviation Log
|
||||
|
||||
| Deviation | Reason |
|
||||
|---|---|
|
||||
| Cherry-picked issue #29 (embedding integration) commits onto this branch | Issue #29 is completed but not yet merged to main. The retrieval pipeline depends on `EmbeddingGenerator` trait and `EmbeddingClient` from #29. |
|
||||
| Made `format_vector_literal` in `embedding/store.rs` public | The retrieval stage modules need to format query vectors as DuckDB literals for SQL queries. Reusing the existing helper avoids duplication. |
|
||||
| Used `CAST(timestamp AS VARCHAR)` + string parsing for timestamps | DuckDB's Rust driver does not implement `FromSql` for `chrono::NaiveDateTime`. Casting to VARCHAR and parsing with chrono is the reliable workaround. |
|
||||
| Changed `MemoryServiceImpl::new()` to accept `RetrievalConfig` parameter | Plan specified adding `retrieval_config` field but the constructor signature change was not explicitly called out. Required to wire config through. |
|
||||
| Replaced `test_query_returns_unimplemented` with `test_query_memory_no_embedding_client` | The query_memory endpoint now returns `FailedPrecondition` instead of `Unimplemented` when no embedding client is configured, since the pipeline is now implemented. |
|
||||
| Changed `execute_pipeline` from async (accepting `EmbeddingGenerator`) to sync (accepting pre-computed `query_vector: &[f32]`) | Retry fix #2: Embedding client mutex was held too long. Moving embedding generation into the caller (service.rs) allows the lock to be dropped before running pipeline stages. |
|
||||
| Boxed `EmbeddingError` in `RetrievalError::Embedding` variant | Clippy `result_large_err` lint flagged the `EmbeddingError` variant as too large (176+ bytes). Boxing resolves this. |
|
||||
|
||||
## Retry Instructions
|
||||
|
||||
### Failure Summary (Attempt 1)
|
||||
|
||||
**Quality Gates:**
|
||||
- Build: PASS
|
||||
- Lint (clippy): PASS
|
||||
- Tests: PASS (143 passed, 0 failed)
|
||||
- Coverage: PASS (all retrieval modules >95%; service.rs at 87.7% acceptable for gRPC wiring)
|
||||
|
||||
**Code Review: REQUEST_CHANGES**
|
||||
|
||||
### Required Fixes
|
||||
|
||||
1. **MAJOR — SQL injection in `services/memory/src/retrieval/stage2.rs` line ~49:**
|
||||
- The `tag_filter` value originates from the gRPC request's `memory_type` field (user-controlled input) and is interpolated directly into SQL with `format!("... WHERE tag = '{tag}'")`
|
||||
- A crafted `memory_type` like `' OR 1=1; --` could manipulate the query
|
||||
- **Fix:** Use a parameterized query placeholder (`?`) and pass the tag value as a bind parameter via `duckdb::params![]`. This is the same pattern used for other parameters elsewhere in the codebase.
|
||||
|
||||
2. **MINOR — Embedding client mutex held too long in `services/memory/src/service.rs` line ~118:**
|
||||
- The `embedding_client` mutex lock is held for the entire `execute_pipeline()` call including all 4 DB stages
|
||||
- Only the initial embedding generation (Stage 0) needs the client
|
||||
- **Fix:** Lock the client, call `generate()` to get the query vector, drop the lock, then pass the vector to the remaining pipeline stages. This may require refactoring `execute_pipeline()` to accept a pre-computed query vector instead of the embedding client.
|
||||
|
||||
3. **MINOR — Missing gRPC streaming integration test in `services/memory/src/service.rs`:**
|
||||
- Add a test `test_query_memory_returns_streamed_results` that populates the DB with test data, attaches a `MockEmbeddingGenerator`, calls `query_memory`, and collects the stream to verify end-to-end wiring.
|
||||
|
||||
4. **After fixing, run `cargo test --workspace` to verify all tests pass.**
|
||||
|
||||
5. **Then run `cargo clippy --workspace -- -D warnings` to verify no new warnings.**
|
||||
Reference in New Issue
Block a user