Merge pull request 'fix: resolve tech debt from issue #31 review (#120)' (#130) from feature/issue-120-tech-debt-31-review into main

This commit was merged in pull request #130.
This commit is contained in: main
2026-03-10 09:39:07 +01:00
7 changed files with 290 additions and 190 deletions

40
Cargo.lock generated
View File

@@ -680,6 +680,21 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
[[package]]
name = "futures"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.32"
@@ -696,12 +711,34 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
[[package]]
name = "futures-executor"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
[[package]]
name = "futures-macro"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "futures-sink"
version = "0.3.32"
@@ -720,8 +757,10 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
@@ -1367,6 +1406,7 @@ dependencies = [
"anyhow",
"chrono",
"duckdb",
"futures",
"llm-multiverse-proto",
"prost",
"prost-types",

View File

@@ -34,6 +34,7 @@
| #114 | Tech debt: minor findings from issue #28 review | Phase 4 | `COMPLETED` | Rust | [issue-114.md](issue-114.md) |
| #116 | Tech debt: minor findings from issue #29 review | Phase 4 | `COMPLETED` | Rust | [issue-116.md](issue-116.md) |
| #118 | Tech debt: minor findings from issue #30 review | Phase 4 | `COMPLETED` | Rust | [issue-118.md](issue-118.md) |
| #120 | Tech debt: minor findings from issue #31 review | Phase 4 | `COMPLETED` | Rust | [issue-120.md](issue-120.md) |
## Status Legend

View File

@@ -0,0 +1,43 @@
# Issue #120: Tech debt: minor findings from issue #31 review
## Summary
Address three tech debt items from the extraction step review:
1. **Add [UNTRUSTED CONTENT] marker for external-provenance memories** - Wrap corpus text in `[UNTRUSTED CONTENT]...[END UNTRUSTED CONTENT]` markers when the memory has external provenance (provenance == 2), providing prompt injection mitigation.
2. **Parallel extraction for batch candidates** - Change `extract_batch` from sequential to parallel processing using `futures::future::join_all`, reducing latency for multiple candidates.
3. **Consolidate duplicate mock gateway test helpers** - Merge the duplicated `FullMockGateway` and the inline `TestGateway` from the `query_memory` test (both in `service.rs`) into a shared `test_helpers` module.
## Item 1: [UNTRUSTED CONTENT] markers
### Approach
Add a `provenance` parameter to `build_extraction_prompt`. When `provenance == 2` (external), wrap the corpus block in `[UNTRUSTED CONTENT]...[END UNTRUSTED CONTENT]` markers. Update `extract` and `extract_batch` to pass the candidate's provenance field through.
### Files changed
- `services/memory/src/extraction/prompt.rs` - Add provenance param to `build_extraction_prompt`
- `services/memory/src/extraction/mod.rs` - Pass provenance through `extract` and `extract_batch`
## Item 2: Parallel extraction
### Approach
Replace the sequential `for` loop in `extract_batch` with `futures::future::join_all` to process all candidates concurrently. The `ExtractionClient` wraps a `ModelGatewayServiceClient` which supports concurrent calls via `clone()`. Note: `join_all` issues all requests at once; if candidate counts grow large, consider bounded concurrency (e.g. `futures::stream::buffer_unordered`) to avoid overloading the gateway.
### Files changed
- `services/memory/Cargo.toml` - Add `futures` dependency
- `services/memory/src/extraction/mod.rs` - Rewrite `extract_batch` with `join_all`
## Item 3: Consolidate mock gateway helpers
### Approach
Create a `test_helpers` module at the top of the `#[cfg(test)]` section in `service.rs` containing a single configurable `MockGateway` struct that handles both embedding and inference RPCs. Replace `FullMockGateway` and the inline `TestGateway` with this unified type.
### Files changed
- `services/memory/src/service.rs` - Consolidate into shared `test_helpers` module, update test functions

View File

@@ -21,6 +21,7 @@ tokio-stream = "0.1"
duckdb = { version = "1", features = ["bundled"] }
chrono = "0.4"
regex = "1"
futures = "0.3"
[dev-dependencies]
tempfile = "3"

View File

@@ -67,6 +67,9 @@ impl ExtractionClient {
/// 4. Include a confidence score (0.0-1.0)
///
/// Uses `TaskComplexity::Simple` for lightweight model routing.
///
/// The `provenance` parameter controls whether the corpus is wrapped in
/// `[UNTRUSTED CONTENT]` markers (when `provenance == 2`, i.e. external).
pub async fn extract(
&self,
context: &SessionContext,
@@ -74,6 +77,7 @@ impl ExtractionClient {
corpus: &str,
memory_name: &str,
memory_id: &str,
provenance: i32,
) -> Result<ExtractionResult, ExtractionError> {
// Skip extraction for empty corpus
if corpus.is_empty() {
@@ -84,7 +88,8 @@ impl ExtractionClient {
});
}
let extraction_prompt = prompt::build_extraction_prompt(query, corpus, memory_name);
let extraction_prompt =
prompt::build_extraction_prompt(query, corpus, memory_name, provenance);
let request = InferenceRequest {
params: Some(InferenceParams {
@@ -104,9 +109,10 @@ impl ExtractionClient {
Ok(prompt::parse_extraction_response(&response.text, memory_id))
}
/// Extract segments for multiple candidates.
/// Extract segments for multiple candidates in parallel.
///
/// Processes candidates sequentially to avoid overloading the gateway.
/// Uses `futures::future::join_all` to process all candidates concurrently,
/// reducing total latency compared to sequential processing.
/// On failure for a single candidate, that candidate's extraction is
/// skipped (returns the full corpus as fallback) and processing continues.
pub async fn extract_batch(
@@ -115,36 +121,38 @@ impl ExtractionClient {
query: &str,
candidates: &[RetrievalCandidate],
) -> Vec<ExtractionResult> {
let mut results = Vec::with_capacity(candidates.len());
for candidate in candidates {
match self
.extract(
context,
query,
&candidate.corpus,
&candidate.name,
&candidate.memory_id,
)
.await
{
Ok(result) => results.push(result),
Err(e) => {
tracing::warn!(
memory_id = %candidate.memory_id,
error = %e,
"Extraction failed for candidate, using full corpus as fallback"
);
results.push(ExtractionResult {
memory_id: candidate.memory_id.clone(),
segment: candidate.corpus.clone(),
confidence: 0.0,
});
let futures: Vec<_> = candidates
.iter()
.map(|candidate| async {
match self
.extract(
context,
query,
&candidate.corpus,
&candidate.name,
&candidate.memory_id,
candidate.provenance,
)
.await
{
Ok(result) => result,
Err(e) => {
tracing::warn!(
memory_id = %candidate.memory_id,
error = %e,
"Extraction failed for candidate, using full corpus as fallback"
);
ExtractionResult {
memory_id: candidate.memory_id.clone(),
segment: candidate.corpus.clone(),
confidence: 0.0,
}
}
}
}
}
})
.collect();
results
futures::future::join_all(futures).await
}
}
@@ -302,7 +310,7 @@ mod tests {
};
let result = client
.extract(&ctx, "test query", "full corpus text here", "mem name", "mem-1")
.extract(&ctx, "test query", "full corpus text here", "mem name", "mem-1", 1)
.await
.unwrap();
@@ -329,7 +337,7 @@ mod tests {
};
let result = client
.extract(&ctx, "test query", "corpus", "name", "mem-1")
.extract(&ctx, "test query", "corpus", "name", "mem-1", 1)
.await;
assert!(result.is_err());
@@ -359,7 +367,7 @@ mod tests {
};
let result = client
.extract(&ctx, "test query", "", "mem name", "mem-1")
.extract(&ctx, "test query", "", "mem name", "mem-1", 1)
.await
.unwrap();

View File

@@ -20,7 +20,22 @@ pub struct ExtractionResult {
/// The prompt instructs the model to extract only the relevant segment
/// from the corpus, given the query context. Output format is structured
/// to enable reliable parsing.
pub fn build_extraction_prompt(query: &str, corpus: &str, memory_name: &str) -> String {
///
/// When `provenance` is `2` (external), the corpus is wrapped in
/// `[UNTRUSTED CONTENT]...[END UNTRUSTED CONTENT]` markers to mitigate
/// prompt injection from external-provenance memories.
pub fn build_extraction_prompt(
query: &str,
corpus: &str,
memory_name: &str,
provenance: i32,
) -> String {
let corpus_block = if provenance == 2 {
format!("[UNTRUSTED CONTENT]\n{corpus}\n[END UNTRUSTED CONTENT]")
} else {
corpus.to_string()
};
format!(
"Given the following search query and memory content, extract ONLY the segment \
of the content that is most relevant to the query. Return a concise, focused \
@@ -31,7 +46,7 @@ Memory name: {memory_name}\n\
\n\
Content:\n\
---\n\
{corpus}\n\
{corpus_block}\n\
---\n\
\n\
Respond in this exact format:\n\
@@ -99,22 +114,43 @@ mod tests {
#[test]
fn test_build_extraction_prompt_contains_query() {
let prompt = build_extraction_prompt("how to sort arrays", "some corpus", "mem name");
let prompt = build_extraction_prompt("how to sort arrays", "some corpus", "mem name", 1);
assert!(prompt.contains("how to sort arrays"));
}
#[test]
fn test_build_extraction_prompt_contains_corpus() {
let prompt = build_extraction_prompt("query", "the corpus content here", "mem name");
let prompt = build_extraction_prompt("query", "the corpus content here", "mem name", 1);
assert!(prompt.contains("the corpus content here"));
}
#[test]
fn test_build_extraction_prompt_contains_memory_name() {
let prompt = build_extraction_prompt("query", "corpus", "Rust Sort Algorithm");
let prompt = build_extraction_prompt("query", "corpus", "Rust Sort Algorithm", 1);
assert!(prompt.contains("Rust Sort Algorithm"));
}
#[test]
fn test_build_extraction_prompt_internal_no_untrusted_marker() {
    // Internal provenance (1) must not wrap the corpus in untrusted-content markers.
    let rendered = build_extraction_prompt("query", "corpus text", "mem", 1);
    for marker in ["[UNTRUSTED CONTENT]", "[END UNTRUSTED CONTENT]"] {
        assert!(!rendered.contains(marker));
    }
}
#[test]
fn test_build_extraction_prompt_external_has_untrusted_marker() {
    // External provenance (2) wraps the corpus in untrusted-content markers
    // while still including the corpus text itself.
    let rendered = build_extraction_prompt("query", "external corpus", "mem", 2);
    let expected = [
        "[UNTRUSTED CONTENT]",
        "[END UNTRUSTED CONTENT]",
        "external corpus",
    ];
    assert!(expected.iter().all(|needle| rendered.contains(needle)));
}
#[test]
fn test_build_extraction_prompt_unspecified_no_untrusted_marker() {
    // Provenance 0 (unspecified) is treated like internal: no markers added.
    let rendered = build_extraction_prompt("query", "corpus text", "mem", 0);
    assert!(!rendered.contains("[UNTRUSTED CONTENT]"));
}
#[test]
fn test_parse_extraction_response_valid() {
let response = "SEGMENT: The quicksort algorithm divides the array.\nCONFIDENCE: 0.8";

View File

@@ -463,6 +463,113 @@ mod tests {
use crate::config::{CacheConfig, ExtractionConfig, ProvenanceConfig, RetrievalConfig};
use llm_multiverse_proto::llm_multiverse::v1::{MemoryEntry, SessionContext};
/// Shared mock Model Gateway for tests that need embedding and/or inference RPCs.
///
/// Replaces the `FullMockGateway` and inline `TestGateway` helpers that were
/// previously duplicated across test functions.
mod test_helpers {
    use crate::db::schema::EMBEDDING_DIM;
    use llm_multiverse_proto::llm_multiverse::v1::model_gateway_service_server::{
        ModelGatewayService, ModelGatewayServiceServer,
    };
    use llm_multiverse_proto::llm_multiverse::v1::{
        GenerateEmbeddingRequest, GenerateEmbeddingResponse, InferenceRequest,
        InferenceResponse, IsModelReadyRequest, IsModelReadyResponse,
        StreamInferenceRequest, StreamInferenceResponse,
    };
    use std::net::SocketAddr;
    use tonic::{Request, Response, Status};

    /// A configurable mock gateway that handles embedding and inference RPCs.
    ///
    /// - `inference_response`: `Some(text)` makes inference calls return `text`;
    ///   `None` makes inference calls return `Unimplemented`.
    /// - Embedding calls always return a vector of `0.10` values sized to
    ///   `EMBEDDING_DIM`.
    pub struct MockGateway {
        pub inference_response: Option<String>,
    }

    impl MockGateway {
        /// Create a gateway that supports embeddings but not inference.
        pub fn embedding_only() -> Self {
            Self {
                inference_response: None,
            }
        }

        /// Create a gateway that supports both embeddings and inference.
        pub fn full(inference_text: &str) -> Self {
            Self {
                inference_response: Some(inference_text.to_string()),
            }
        }
    }

    #[tonic::async_trait]
    impl ModelGatewayService for MockGateway {
        type StreamInferenceStream =
            tokio_stream::Empty<Result<StreamInferenceResponse, Status>>;

        // Streaming inference is never exercised by these tests.
        async fn stream_inference(
            &self,
            _request: Request<StreamInferenceRequest>,
        ) -> Result<Response<Self::StreamInferenceStream>, Status> {
            Err(Status::unimplemented("not needed"))
        }

        // Returns the configured canned text, or Unimplemented when the
        // gateway was built via `embedding_only()`.
        async fn inference(
            &self,
            _request: Request<InferenceRequest>,
        ) -> Result<Response<InferenceResponse>, Status> {
            let Some(text) = &self.inference_response else {
                return Err(Status::unimplemented("not needed"));
            };
            Ok(Response::new(InferenceResponse {
                text: text.clone(),
                finish_reason: "stop".to_string(),
                tokens_used: 30,
            }))
        }

        // Always succeeds with a constant 0.10-filled embedding vector.
        async fn generate_embedding(
            &self,
            _request: Request<GenerateEmbeddingRequest>,
        ) -> Result<Response<GenerateEmbeddingResponse>, Status> {
            Ok(Response::new(GenerateEmbeddingResponse {
                embedding: vec![0.10_f32; EMBEDDING_DIM],
                dimensions: EMBEDDING_DIM as u32,
            }))
        }

        async fn is_model_ready(
            &self,
            _request: Request<IsModelReadyRequest>,
        ) -> Result<Response<IsModelReadyResponse>, Status> {
            Err(Status::unimplemented("not needed"))
        }
    }

    /// Spawn a mock Model Gateway server on an ephemeral port and return its
    /// address. The server task runs until the test process exits.
    pub async fn start_mock_gateway(gateway: MockGateway) -> SocketAddr {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .unwrap();
        let addr = listener.local_addr().unwrap();
        let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
        tokio::spawn(async move {
            tonic::transport::Server::builder()
                .add_service(ModelGatewayServiceServer::new(gateway))
                .serve_with_incoming(incoming)
                .await
                .unwrap();
        });
        // Best-effort pause to give the spawned server time to start accepting
        // connections before the caller connects.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        addr
    }
}
fn valid_ctx() -> SessionContext {
SessionContext {
session_id: "sess-1".into(),
@@ -649,7 +756,6 @@ mod tests {
.expect("DB setup failed");
// Create service with mock embedding client via a real mock gateway server.
// (MockEmbeddingGenerator can't be used directly since the field expects EmbeddingClient.)
let db_arc = Arc::new(db);
let mut svc = MemoryServiceImpl::new(
db_arc,
@@ -658,75 +764,9 @@ mod tests {
ExtractionConfig::default(),
CacheConfig::default(),
);
// Set embedding client to mock by storing it as Arc<Mutex<dyn EmbeddingGenerator>>
// Since the field type is Option<Arc<Mutex<EmbeddingClient>>>, we need a different approach.
// Instead, set up a real mock gateway server.
// Use the mock gateway server approach from embedding tests
use llm_multiverse_proto::llm_multiverse::v1::model_gateway_service_server::{
ModelGatewayService, ModelGatewayServiceServer,
};
use llm_multiverse_proto::llm_multiverse::v1::{
GenerateEmbeddingRequest, GenerateEmbeddingResponse, InferenceRequest,
InferenceResponse, IsModelReadyRequest, IsModelReadyResponse,
StreamInferenceRequest, StreamInferenceResponse,
};
struct TestGateway;
#[tonic::async_trait]
impl ModelGatewayService for TestGateway {
type StreamInferenceStream =
tokio_stream::Empty<Result<StreamInferenceResponse, Status>>;
async fn stream_inference(
&self,
_request: Request<StreamInferenceRequest>,
) -> Result<Response<Self::StreamInferenceStream>, Status> {
Err(Status::unimplemented("not needed"))
}
async fn inference(
&self,
_request: Request<InferenceRequest>,
) -> Result<Response<InferenceResponse>, Status> {
Err(Status::unimplemented("not needed"))
}
async fn generate_embedding(
&self,
_request: Request<GenerateEmbeddingRequest>,
) -> Result<Response<GenerateEmbeddingResponse>, Status> {
// Return a vector with value 0.10 matching test data
let embedding = vec![0.10_f32; EMBEDDING_DIM];
Ok(Response::new(GenerateEmbeddingResponse {
embedding,
dimensions: EMBEDDING_DIM as u32,
}))
}
async fn is_model_ready(
&self,
_request: Request<IsModelReadyRequest>,
) -> Result<Response<IsModelReadyResponse>, Status> {
Err(Status::unimplemented("not needed"))
}
}
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind failed");
let addr = listener.local_addr().expect("no local addr");
let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
tokio::spawn(async move {
tonic::transport::Server::builder()
.add_service(ModelGatewayServiceServer::new(TestGateway))
.serve_with_incoming(incoming)
.await
.expect("server failed");
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let addr =
test_helpers::start_mock_gateway(test_helpers::MockGateway::embedding_only()).await;
let embedding_client = crate::embedding::EmbeddingClient::connect(&format!("http://{addr}"))
.await
.expect("connect failed");
@@ -982,81 +1022,12 @@ mod tests {
assert_eq!(prov.trust_level, crate::provenance::TrustLevel::Revoked);
}
/// Helper to set up a mock gateway server that handles both embedding and inference.
/// Returns the gateway address.
mod extraction_test_helpers {
use crate::db::schema::EMBEDDING_DIM;
use llm_multiverse_proto::llm_multiverse::v1::model_gateway_service_server::{
ModelGatewayService, ModelGatewayServiceServer,
};
use llm_multiverse_proto::llm_multiverse::v1::{
GenerateEmbeddingRequest, GenerateEmbeddingResponse, InferenceRequest,
InferenceResponse, IsModelReadyRequest, IsModelReadyResponse,
StreamInferenceRequest, StreamInferenceResponse,
};
use std::net::SocketAddr;
use tonic::{Request, Response, Status};
pub struct FullMockGateway;
#[tonic::async_trait]
impl ModelGatewayService for FullMockGateway {
type StreamInferenceStream =
tokio_stream::Empty<Result<StreamInferenceResponse, Status>>;
async fn stream_inference(
&self,
_request: Request<StreamInferenceRequest>,
) -> Result<Response<Self::StreamInferenceStream>, Status> {
Err(Status::unimplemented("not needed"))
}
async fn inference(
&self,
_request: Request<InferenceRequest>,
) -> Result<Response<InferenceResponse>, Status> {
Ok(Response::new(InferenceResponse {
text: "SEGMENT: Extracted relevant text.\nCONFIDENCE: 0.85".to_string(),
finish_reason: "stop".to_string(),
tokens_used: 30,
}))
}
async fn generate_embedding(
&self,
_request: Request<GenerateEmbeddingRequest>,
) -> Result<Response<GenerateEmbeddingResponse>, Status> {
let embedding = vec![0.10_f32; EMBEDDING_DIM];
Ok(Response::new(GenerateEmbeddingResponse {
embedding,
dimensions: EMBEDDING_DIM as u32,
}))
}
async fn is_model_ready(
&self,
_request: Request<IsModelReadyRequest>,
) -> Result<Response<IsModelReadyResponse>, Status> {
Err(Status::unimplemented("not needed"))
}
}
pub async fn start_full_mock_gateway() -> SocketAddr {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
tokio::spawn(async move {
tonic::transport::Server::builder()
.add_service(ModelGatewayServiceServer::new(FullMockGateway))
.serve_with_incoming(incoming)
.await
.unwrap();
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
addr
}
/// Start a full mock gateway (embedding + inference) for extraction tests.
///
/// Inference calls return a fixed, parseable extraction response.
async fn start_full_mock_gateway() -> std::net::SocketAddr {
    let gateway =
        test_helpers::MockGateway::full("SEGMENT: Extracted relevant text.\nCONFIDENCE: 0.85");
    test_helpers::start_mock_gateway(gateway).await
}
/// Helper to populate DB with test data for extraction tests.
@@ -1101,7 +1072,7 @@ mod tests {
populate_test_db(&db);
let db_arc = Arc::new(db);
let addr = extraction_test_helpers::start_full_mock_gateway().await;
let addr = start_full_mock_gateway().await;
let endpoint = format!("http://{addr}");
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
@@ -1168,7 +1139,7 @@ mod tests {
populate_test_db(&db);
let db_arc = Arc::new(db);
let addr = extraction_test_helpers::start_full_mock_gateway().await;
let addr = start_full_mock_gateway().await;
let endpoint = format!("http://{addr}");
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
@@ -1230,7 +1201,7 @@ mod tests {
populate_test_db(&db);
let db_arc = Arc::new(db);
let addr = extraction_test_helpers::start_full_mock_gateway().await;
let addr = start_full_mock_gateway().await;
let endpoint = format!("http://{addr}");
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
@@ -1293,7 +1264,7 @@ mod tests {
populate_test_db(&db);
let db_arc = Arc::new(db);
let addr = extraction_test_helpers::start_full_mock_gateway().await;
let addr = start_full_mock_gateway().await;
let endpoint = format!("http://{addr}");
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
@@ -1344,7 +1315,7 @@ mod tests {
populate_test_db(&db);
let db_arc = Arc::new(db);
let addr = extraction_test_helpers::start_full_mock_gateway().await;
let addr = start_full_mock_gateway().await;
let endpoint = format!("http://{addr}");
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
@@ -1413,7 +1384,7 @@ mod tests {
populate_test_db(&db);
let db_arc = Arc::new(db);
let addr = extraction_test_helpers::start_full_mock_gateway().await;
let addr = start_full_mock_gateway().await;
let endpoint = format!("http://{addr}");
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
@@ -1474,7 +1445,7 @@ mod tests {
populate_test_db(&db);
let db_arc = Arc::new(db);
let addr = extraction_test_helpers::start_full_mock_gateway().await;
let addr = start_full_mock_gateway().await;
let endpoint = format!("http://{addr}");
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)