Merge pull request 'fix: resolve tech debt from issue #31 review (#120)' (#130) from feature/issue-120-tech-debt-31-review into main
This commit was merged in pull request #130.
This commit is contained in:
40
Cargo.lock
generated
40
Cargo.lock
generated
@@ -680,6 +680,21 @@ version = "2.0.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
|
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures"
|
||||||
|
version = "0.3.32"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d"
|
||||||
|
dependencies = [
|
||||||
|
"futures-channel",
|
||||||
|
"futures-core",
|
||||||
|
"futures-executor",
|
||||||
|
"futures-io",
|
||||||
|
"futures-sink",
|
||||||
|
"futures-task",
|
||||||
|
"futures-util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-channel"
|
name = "futures-channel"
|
||||||
version = "0.3.32"
|
version = "0.3.32"
|
||||||
@@ -696,12 +711,34 @@ version = "0.3.32"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
|
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-executor"
|
||||||
|
version = "0.3.32"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d"
|
||||||
|
dependencies = [
|
||||||
|
"futures-core",
|
||||||
|
"futures-task",
|
||||||
|
"futures-util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-io"
|
name = "futures-io"
|
||||||
version = "0.3.32"
|
version = "0.3.32"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
|
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-macro"
|
||||||
|
version = "0.3.32"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.117",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-sink"
|
name = "futures-sink"
|
||||||
version = "0.3.32"
|
version = "0.3.32"
|
||||||
@@ -720,8 +757,10 @@ version = "0.3.32"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
|
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"futures-channel",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-io",
|
"futures-io",
|
||||||
|
"futures-macro",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
"futures-task",
|
"futures-task",
|
||||||
"memchr",
|
"memchr",
|
||||||
@@ -1367,6 +1406,7 @@ dependencies = [
|
|||||||
"anyhow",
|
"anyhow",
|
||||||
"chrono",
|
"chrono",
|
||||||
"duckdb",
|
"duckdb",
|
||||||
|
"futures",
|
||||||
"llm-multiverse-proto",
|
"llm-multiverse-proto",
|
||||||
"prost",
|
"prost",
|
||||||
"prost-types",
|
"prost-types",
|
||||||
|
|||||||
@@ -34,6 +34,7 @@
|
|||||||
| #114 | Tech debt: minor findings from issue #28 review | Phase 4 | `COMPLETED` | Rust | [issue-114.md](issue-114.md) |
|
| #114 | Tech debt: minor findings from issue #28 review | Phase 4 | `COMPLETED` | Rust | [issue-114.md](issue-114.md) |
|
||||||
| #116 | Tech debt: minor findings from issue #29 review | Phase 4 | `COMPLETED` | Rust | [issue-116.md](issue-116.md) |
|
| #116 | Tech debt: minor findings from issue #29 review | Phase 4 | `COMPLETED` | Rust | [issue-116.md](issue-116.md) |
|
||||||
| #118 | Tech debt: minor findings from issue #30 review | Phase 4 | `COMPLETED` | Rust | [issue-118.md](issue-118.md) |
|
| #118 | Tech debt: minor findings from issue #30 review | Phase 4 | `COMPLETED` | Rust | [issue-118.md](issue-118.md) |
|
||||||
|
| #120 | Tech debt: minor findings from issue #31 review | Phase 4 | `COMPLETED` | Rust | [issue-120.md](issue-120.md) |
|
||||||
|
|
||||||
## Status Legend
|
## Status Legend
|
||||||
|
|
||||||
|
|||||||
43
implementation-plans/issue-120.md
Normal file
43
implementation-plans/issue-120.md
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
# Issue #120: Tech debt: minor findings from issue #31 review
|
||||||
|
|
||||||
|
## Summary
|
||||||
|
|
||||||
|
Address three tech debt items from the extraction step review:
|
||||||
|
|
||||||
|
1. **Add [UNTRUSTED CONTENT] marker for external-provenance memories** - Wrap corpus text in `[UNTRUSTED CONTENT]...[END UNTRUSTED CONTENT]` markers when the memory has external provenance (provenance == 2), providing prompt injection mitigation.
|
||||||
|
|
||||||
|
2. **Parallel extraction for batch candidates** - Change `extract_batch` from sequential to concurrent processing using `futures::future::join_all` (all candidate extractions are awaited together rather than one after another), reducing latency for multiple candidates.
|
||||||
|
|
||||||
|
3. **Consolidate duplicate mock gateway test helpers** - Merge `FullMockGateway` (service.rs) and `TestGateway` (service.rs query_memory test) into a shared `test_helpers` module.
|
||||||
|
|
||||||
|
## Item 1: [UNTRUSTED CONTENT] markers
|
||||||
|
|
||||||
|
### Approach
|
||||||
|
|
||||||
|
Add a `provenance` parameter to `build_extraction_prompt`. When `provenance == 2` (external), wrap the corpus block in `[UNTRUSTED CONTENT]...[END UNTRUSTED CONTENT]` markers. Update `extract` and `extract_batch` to pass the candidate's provenance field through.
|
||||||
|
|
||||||
|
### Files changed
|
||||||
|
|
||||||
|
- `services/memory/src/extraction/prompt.rs` - Add provenance param to `build_extraction_prompt`
|
||||||
|
- `services/memory/src/extraction/mod.rs` - Pass provenance through `extract` and `extract_batch`
|
||||||
|
|
||||||
|
## Item 2: Parallel extraction
|
||||||
|
|
||||||
|
### Approach
|
||||||
|
|
||||||
|
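Replace the sequential `for` loop in `extract_batch` with `futures::future::join_all` to process all candidates concurrently. The `ExtractionClient` wraps a `ModelGatewayServiceClient` which supports concurrent calls via `clone()`. Note that this removes the implicit throttling of the old sequential loop (which existed to avoid overloading the gateway): every candidate's request is now issued at once, so the gateway can see up to one in-flight call per candidate in the batch.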
|
||||||
|
|
||||||
|
### Files changed
|
||||||
|
|
||||||
|
- `services/memory/Cargo.toml` - Add `futures` dependency
|
||||||
|
- `services/memory/src/extraction/mod.rs` - Rewrite `extract_batch` with `join_all`
|
||||||
|
|
||||||
|
## Item 3: Consolidate mock gateway helpers
|
||||||
|
|
||||||
|
### Approach
|
||||||
|
|
||||||
|
Create a `test_helpers` module at the top of the `#[cfg(test)]` section in `service.rs` containing a single configurable `MockGateway` struct that handles both embedding and inference RPCs. Replace `FullMockGateway` and the inline `TestGateway` with this unified type.
|
||||||
|
|
||||||
|
### Files changed
|
||||||
|
|
||||||
|
- `services/memory/src/service.rs` - Consolidate into shared `test_helpers` module, update test functions
|
||||||
@@ -21,6 +21,7 @@ tokio-stream = "0.1"
|
|||||||
duckdb = { version = "1", features = ["bundled"] }
|
duckdb = { version = "1", features = ["bundled"] }
|
||||||
chrono = "0.4"
|
chrono = "0.4"
|
||||||
regex = "1"
|
regex = "1"
|
||||||
|
futures = "0.3"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = "3"
|
tempfile = "3"
|
||||||
|
|||||||
@@ -67,6 +67,9 @@ impl ExtractionClient {
|
|||||||
/// 4. Include a confidence score (0.0-1.0)
|
/// 4. Include a confidence score (0.0-1.0)
|
||||||
///
|
///
|
||||||
/// Uses `TaskComplexity::Simple` for lightweight model routing.
|
/// Uses `TaskComplexity::Simple` for lightweight model routing.
|
||||||
|
///
|
||||||
|
/// The `provenance` parameter controls whether the corpus is wrapped in
|
||||||
|
/// `[UNTRUSTED CONTENT]` markers (when `provenance == 2`, i.e. external).
|
||||||
pub async fn extract(
|
pub async fn extract(
|
||||||
&self,
|
&self,
|
||||||
context: &SessionContext,
|
context: &SessionContext,
|
||||||
@@ -74,6 +77,7 @@ impl ExtractionClient {
|
|||||||
corpus: &str,
|
corpus: &str,
|
||||||
memory_name: &str,
|
memory_name: &str,
|
||||||
memory_id: &str,
|
memory_id: &str,
|
||||||
|
provenance: i32,
|
||||||
) -> Result<ExtractionResult, ExtractionError> {
|
) -> Result<ExtractionResult, ExtractionError> {
|
||||||
// Skip extraction for empty corpus
|
// Skip extraction for empty corpus
|
||||||
if corpus.is_empty() {
|
if corpus.is_empty() {
|
||||||
@@ -84,7 +88,8 @@ impl ExtractionClient {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let extraction_prompt = prompt::build_extraction_prompt(query, corpus, memory_name);
|
let extraction_prompt =
|
||||||
|
prompt::build_extraction_prompt(query, corpus, memory_name, provenance);
|
||||||
|
|
||||||
let request = InferenceRequest {
|
let request = InferenceRequest {
|
||||||
params: Some(InferenceParams {
|
params: Some(InferenceParams {
|
||||||
@@ -104,9 +109,10 @@ impl ExtractionClient {
|
|||||||
Ok(prompt::parse_extraction_response(&response.text, memory_id))
|
Ok(prompt::parse_extraction_response(&response.text, memory_id))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extract segments for multiple candidates.
|
/// Extract segments for multiple candidates concurrently.
|
||||||
///
|
///
|
||||||
/// Processes candidates sequentially to avoid overloading the gateway.
|
/// Uses `futures::future::join_all` to process all candidates concurrently,
|
||||||
|
/// reducing total latency compared to sequential processing.
|
||||||
/// On failure for a single candidate, that candidate's extraction is
|
/// On failure for a single candidate, that candidate's extraction is
|
||||||
/// skipped (returns the full corpus as fallback) and processing continues.
|
/// skipped (returns the full corpus as fallback) and processing continues.
|
||||||
pub async fn extract_batch(
|
pub async fn extract_batch(
|
||||||
@@ -115,9 +121,9 @@ impl ExtractionClient {
|
|||||||
query: &str,
|
query: &str,
|
||||||
candidates: &[RetrievalCandidate],
|
candidates: &[RetrievalCandidate],
|
||||||
) -> Vec<ExtractionResult> {
|
) -> Vec<ExtractionResult> {
|
||||||
let mut results = Vec::with_capacity(candidates.len());
|
let futures: Vec<_> = candidates
|
||||||
|
.iter()
|
||||||
for candidate in candidates {
|
.map(|candidate| async {
|
||||||
match self
|
match self
|
||||||
.extract(
|
.extract(
|
||||||
context,
|
context,
|
||||||
@@ -125,26 +131,28 @@ impl ExtractionClient {
|
|||||||
&candidate.corpus,
|
&candidate.corpus,
|
||||||
&candidate.name,
|
&candidate.name,
|
||||||
&candidate.memory_id,
|
&candidate.memory_id,
|
||||||
|
candidate.provenance,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(result) => results.push(result),
|
Ok(result) => result,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
memory_id = %candidate.memory_id,
|
memory_id = %candidate.memory_id,
|
||||||
error = %e,
|
error = %e,
|
||||||
"Extraction failed for candidate, using full corpus as fallback"
|
"Extraction failed for candidate, using full corpus as fallback"
|
||||||
);
|
);
|
||||||
results.push(ExtractionResult {
|
ExtractionResult {
|
||||||
memory_id: candidate.memory_id.clone(),
|
memory_id: candidate.memory_id.clone(),
|
||||||
segment: candidate.corpus.clone(),
|
segment: candidate.corpus.clone(),
|
||||||
confidence: 0.0,
|
confidence: 0.0,
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
results
|
futures::future::join_all(futures).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -302,7 +310,7 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let result = client
|
let result = client
|
||||||
.extract(&ctx, "test query", "full corpus text here", "mem name", "mem-1")
|
.extract(&ctx, "test query", "full corpus text here", "mem name", "mem-1", 1)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@@ -329,7 +337,7 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let result = client
|
let result = client
|
||||||
.extract(&ctx, "test query", "corpus", "name", "mem-1")
|
.extract(&ctx, "test query", "corpus", "name", "mem-1", 1)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
@@ -359,7 +367,7 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let result = client
|
let result = client
|
||||||
.extract(&ctx, "test query", "", "mem name", "mem-1")
|
.extract(&ctx, "test query", "", "mem name", "mem-1", 1)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,22 @@ pub struct ExtractionResult {
|
|||||||
/// The prompt instructs the model to extract only the relevant segment
|
/// The prompt instructs the model to extract only the relevant segment
|
||||||
/// from the corpus, given the query context. Output format is structured
|
/// from the corpus, given the query context. Output format is structured
|
||||||
/// to enable reliable parsing.
|
/// to enable reliable parsing.
|
||||||
pub fn build_extraction_prompt(query: &str, corpus: &str, memory_name: &str) -> String {
|
///
|
||||||
|
/// When `provenance` is `2` (external), the corpus is wrapped in
|
||||||
|
/// `[UNTRUSTED CONTENT]...[END UNTRUSTED CONTENT]` markers to mitigate
|
||||||
|
/// prompt injection from external-provenance memories.
|
||||||
|
pub fn build_extraction_prompt(
|
||||||
|
query: &str,
|
||||||
|
corpus: &str,
|
||||||
|
memory_name: &str,
|
||||||
|
provenance: i32,
|
||||||
|
) -> String {
|
||||||
|
let corpus_block = if provenance == 2 {
|
||||||
|
format!("[UNTRUSTED CONTENT]\n{corpus}\n[END UNTRUSTED CONTENT]")
|
||||||
|
} else {
|
||||||
|
corpus.to_string()
|
||||||
|
};
|
||||||
|
|
||||||
format!(
|
format!(
|
||||||
"Given the following search query and memory content, extract ONLY the segment \
|
"Given the following search query and memory content, extract ONLY the segment \
|
||||||
of the content that is most relevant to the query. Return a concise, focused \
|
of the content that is most relevant to the query. Return a concise, focused \
|
||||||
@@ -31,7 +46,7 @@ Memory name: {memory_name}\n\
|
|||||||
\n\
|
\n\
|
||||||
Content:\n\
|
Content:\n\
|
||||||
---\n\
|
---\n\
|
||||||
{corpus}\n\
|
{corpus_block}\n\
|
||||||
---\n\
|
---\n\
|
||||||
\n\
|
\n\
|
||||||
Respond in this exact format:\n\
|
Respond in this exact format:\n\
|
||||||
@@ -99,22 +114,43 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_build_extraction_prompt_contains_query() {
|
fn test_build_extraction_prompt_contains_query() {
|
||||||
let prompt = build_extraction_prompt("how to sort arrays", "some corpus", "mem name");
|
let prompt = build_extraction_prompt("how to sort arrays", "some corpus", "mem name", 1);
|
||||||
assert!(prompt.contains("how to sort arrays"));
|
assert!(prompt.contains("how to sort arrays"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_build_extraction_prompt_contains_corpus() {
|
fn test_build_extraction_prompt_contains_corpus() {
|
||||||
let prompt = build_extraction_prompt("query", "the corpus content here", "mem name");
|
let prompt = build_extraction_prompt("query", "the corpus content here", "mem name", 1);
|
||||||
assert!(prompt.contains("the corpus content here"));
|
assert!(prompt.contains("the corpus content here"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_build_extraction_prompt_contains_memory_name() {
|
fn test_build_extraction_prompt_contains_memory_name() {
|
||||||
let prompt = build_extraction_prompt("query", "corpus", "Rust Sort Algorithm");
|
let prompt = build_extraction_prompt("query", "corpus", "Rust Sort Algorithm", 1);
|
||||||
assert!(prompt.contains("Rust Sort Algorithm"));
|
assert!(prompt.contains("Rust Sort Algorithm"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_extraction_prompt_internal_no_untrusted_marker() {
|
||||||
|
let prompt = build_extraction_prompt("query", "corpus text", "mem", 1);
|
||||||
|
assert!(!prompt.contains("[UNTRUSTED CONTENT]"));
|
||||||
|
assert!(!prompt.contains("[END UNTRUSTED CONTENT]"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_extraction_prompt_external_has_untrusted_marker() {
|
||||||
|
let prompt = build_extraction_prompt("query", "external corpus", "mem", 2);
|
||||||
|
assert!(prompt.contains("[UNTRUSTED CONTENT]"));
|
||||||
|
assert!(prompt.contains("[END UNTRUSTED CONTENT]"));
|
||||||
|
assert!(prompt.contains("external corpus"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_extraction_prompt_unspecified_no_untrusted_marker() {
|
||||||
|
let prompt = build_extraction_prompt("query", "corpus text", "mem", 0);
|
||||||
|
assert!(!prompt.contains("[UNTRUSTED CONTENT]"));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parse_extraction_response_valid() {
|
fn test_parse_extraction_response_valid() {
|
||||||
let response = "SEGMENT: The quicksort algorithm divides the array.\nCONFIDENCE: 0.8";
|
let response = "SEGMENT: The quicksort algorithm divides the array.\nCONFIDENCE: 0.8";
|
||||||
|
|||||||
@@ -463,6 +463,113 @@ mod tests {
|
|||||||
use crate::config::{CacheConfig, ExtractionConfig, ProvenanceConfig, RetrievalConfig};
|
use crate::config::{CacheConfig, ExtractionConfig, ProvenanceConfig, RetrievalConfig};
|
||||||
use llm_multiverse_proto::llm_multiverse::v1::{MemoryEntry, SessionContext};
|
use llm_multiverse_proto::llm_multiverse::v1::{MemoryEntry, SessionContext};
|
||||||
|
|
||||||
|
/// Shared mock Model Gateway for tests that need embedding and/or inference RPCs.
|
||||||
|
///
|
||||||
|
/// Consolidates `FullMockGateway` and the inline `TestGateway` that were
|
||||||
|
/// previously duplicated across test functions.
|
||||||
|
mod test_helpers {
|
||||||
|
use crate::db::schema::EMBEDDING_DIM;
|
||||||
|
use llm_multiverse_proto::llm_multiverse::v1::model_gateway_service_server::{
|
||||||
|
ModelGatewayService, ModelGatewayServiceServer,
|
||||||
|
};
|
||||||
|
use llm_multiverse_proto::llm_multiverse::v1::{
|
||||||
|
GenerateEmbeddingRequest, GenerateEmbeddingResponse, InferenceRequest,
|
||||||
|
InferenceResponse, IsModelReadyRequest, IsModelReadyResponse,
|
||||||
|
StreamInferenceRequest, StreamInferenceResponse,
|
||||||
|
};
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use tonic::{Request, Response, Status};
|
||||||
|
|
||||||
|
/// A configurable mock gateway that handles embedding and inference RPCs.
|
||||||
|
///
|
||||||
|
/// - `inference_response`: If `Some`, inference calls return this text.
|
||||||
|
/// If `None`, inference calls return `Unimplemented`.
|
||||||
|
/// - Embedding calls always return a vector of `0.10` values matching
|
||||||
|
/// `EMBEDDING_DIM`.
|
||||||
|
pub struct MockGateway {
|
||||||
|
pub inference_response: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MockGateway {
|
||||||
|
/// Create a gateway that supports embeddings but not inference.
|
||||||
|
pub fn embedding_only() -> Self {
|
||||||
|
Self {
|
||||||
|
inference_response: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a gateway that supports both embeddings and inference.
|
||||||
|
pub fn full(inference_text: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
inference_response: Some(inference_text.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tonic::async_trait]
|
||||||
|
impl ModelGatewayService for MockGateway {
|
||||||
|
type StreamInferenceStream =
|
||||||
|
tokio_stream::Empty<Result<StreamInferenceResponse, Status>>;
|
||||||
|
|
||||||
|
async fn stream_inference(
|
||||||
|
&self,
|
||||||
|
_request: Request<StreamInferenceRequest>,
|
||||||
|
) -> Result<Response<Self::StreamInferenceStream>, Status> {
|
||||||
|
Err(Status::unimplemented("not needed"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn inference(
|
||||||
|
&self,
|
||||||
|
_request: Request<InferenceRequest>,
|
||||||
|
) -> Result<Response<InferenceResponse>, Status> {
|
||||||
|
match &self.inference_response {
|
||||||
|
Some(text) => Ok(Response::new(InferenceResponse {
|
||||||
|
text: text.clone(),
|
||||||
|
finish_reason: "stop".to_string(),
|
||||||
|
tokens_used: 30,
|
||||||
|
})),
|
||||||
|
None => Err(Status::unimplemented("not needed")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn generate_embedding(
|
||||||
|
&self,
|
||||||
|
_request: Request<GenerateEmbeddingRequest>,
|
||||||
|
) -> Result<Response<GenerateEmbeddingResponse>, Status> {
|
||||||
|
let embedding = vec![0.10_f32; EMBEDDING_DIM];
|
||||||
|
Ok(Response::new(GenerateEmbeddingResponse {
|
||||||
|
embedding,
|
||||||
|
dimensions: EMBEDDING_DIM as u32,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn is_model_ready(
|
||||||
|
&self,
|
||||||
|
_request: Request<IsModelReadyRequest>,
|
||||||
|
) -> Result<Response<IsModelReadyResponse>, Status> {
|
||||||
|
Err(Status::unimplemented("not needed"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start a mock Model Gateway server and return its address.
|
||||||
|
pub async fn start_mock_gateway(gateway: MockGateway) -> SocketAddr {
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
tonic::transport::Server::builder()
|
||||||
|
.add_service(ModelGatewayServiceServer::new(gateway))
|
||||||
|
.serve_with_incoming(incoming)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
});
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||||
|
addr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn valid_ctx() -> SessionContext {
|
fn valid_ctx() -> SessionContext {
|
||||||
SessionContext {
|
SessionContext {
|
||||||
session_id: "sess-1".into(),
|
session_id: "sess-1".into(),
|
||||||
@@ -649,7 +756,6 @@ mod tests {
|
|||||||
.expect("DB setup failed");
|
.expect("DB setup failed");
|
||||||
|
|
||||||
// Create service with mock embedding client via a real mock gateway server.
|
// Create service with mock embedding client via a real mock gateway server.
|
||||||
// (MockEmbeddingGenerator can't be used directly since the field expects EmbeddingClient.)
|
|
||||||
let db_arc = Arc::new(db);
|
let db_arc = Arc::new(db);
|
||||||
let mut svc = MemoryServiceImpl::new(
|
let mut svc = MemoryServiceImpl::new(
|
||||||
db_arc,
|
db_arc,
|
||||||
@@ -658,75 +764,9 @@ mod tests {
|
|||||||
ExtractionConfig::default(),
|
ExtractionConfig::default(),
|
||||||
CacheConfig::default(),
|
CacheConfig::default(),
|
||||||
);
|
);
|
||||||
// Set embedding client to mock by storing it as Arc<Mutex<dyn EmbeddingGenerator>>
|
|
||||||
// Since the field type is Option<Arc<Mutex<EmbeddingClient>>>, we need a different approach.
|
|
||||||
// Instead, set up a real mock gateway server.
|
|
||||||
|
|
||||||
// Use the mock gateway server approach from embedding tests
|
|
||||||
use llm_multiverse_proto::llm_multiverse::v1::model_gateway_service_server::{
|
|
||||||
ModelGatewayService, ModelGatewayServiceServer,
|
|
||||||
};
|
|
||||||
use llm_multiverse_proto::llm_multiverse::v1::{
|
|
||||||
GenerateEmbeddingRequest, GenerateEmbeddingResponse, InferenceRequest,
|
|
||||||
InferenceResponse, IsModelReadyRequest, IsModelReadyResponse,
|
|
||||||
StreamInferenceRequest, StreamInferenceResponse,
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TestGateway;
|
|
||||||
|
|
||||||
#[tonic::async_trait]
|
|
||||||
impl ModelGatewayService for TestGateway {
|
|
||||||
type StreamInferenceStream =
|
|
||||||
tokio_stream::Empty<Result<StreamInferenceResponse, Status>>;
|
|
||||||
|
|
||||||
async fn stream_inference(
|
|
||||||
&self,
|
|
||||||
_request: Request<StreamInferenceRequest>,
|
|
||||||
) -> Result<Response<Self::StreamInferenceStream>, Status> {
|
|
||||||
Err(Status::unimplemented("not needed"))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn inference(
|
|
||||||
&self,
|
|
||||||
_request: Request<InferenceRequest>,
|
|
||||||
) -> Result<Response<InferenceResponse>, Status> {
|
|
||||||
Err(Status::unimplemented("not needed"))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn generate_embedding(
|
|
||||||
&self,
|
|
||||||
_request: Request<GenerateEmbeddingRequest>,
|
|
||||||
) -> Result<Response<GenerateEmbeddingResponse>, Status> {
|
|
||||||
// Return a vector with value 0.10 matching test data
|
|
||||||
let embedding = vec![0.10_f32; EMBEDDING_DIM];
|
|
||||||
Ok(Response::new(GenerateEmbeddingResponse {
|
|
||||||
embedding,
|
|
||||||
dimensions: EMBEDDING_DIM as u32,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn is_model_ready(
|
|
||||||
&self,
|
|
||||||
_request: Request<IsModelReadyRequest>,
|
|
||||||
) -> Result<Response<IsModelReadyResponse>, Status> {
|
|
||||||
Err(Status::unimplemented("not needed"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
|
||||||
.await
|
|
||||||
.expect("bind failed");
|
|
||||||
let addr = listener.local_addr().expect("no local addr");
|
|
||||||
let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
|
|
||||||
tokio::spawn(async move {
|
|
||||||
tonic::transport::Server::builder()
|
|
||||||
.add_service(ModelGatewayServiceServer::new(TestGateway))
|
|
||||||
.serve_with_incoming(incoming)
|
|
||||||
.await
|
|
||||||
.expect("server failed");
|
|
||||||
});
|
|
||||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
|
||||||
|
|
||||||
|
let addr =
|
||||||
|
test_helpers::start_mock_gateway(test_helpers::MockGateway::embedding_only()).await;
|
||||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&format!("http://{addr}"))
|
let embedding_client = crate::embedding::EmbeddingClient::connect(&format!("http://{addr}"))
|
||||||
.await
|
.await
|
||||||
.expect("connect failed");
|
.expect("connect failed");
|
||||||
@@ -982,81 +1022,12 @@ mod tests {
|
|||||||
assert_eq!(prov.trust_level, crate::provenance::TrustLevel::Revoked);
|
assert_eq!(prov.trust_level, crate::provenance::TrustLevel::Revoked);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper to set up a mock gateway server that handles both embedding and inference.
|
/// Start a full mock gateway (embedding + inference) for extraction tests.
|
||||||
/// Returns the gateway address.
|
async fn start_full_mock_gateway() -> std::net::SocketAddr {
|
||||||
mod extraction_test_helpers {
|
test_helpers::start_mock_gateway(test_helpers::MockGateway::full(
|
||||||
use crate::db::schema::EMBEDDING_DIM;
|
"SEGMENT: Extracted relevant text.\nCONFIDENCE: 0.85",
|
||||||
use llm_multiverse_proto::llm_multiverse::v1::model_gateway_service_server::{
|
))
|
||||||
ModelGatewayService, ModelGatewayServiceServer,
|
|
||||||
};
|
|
||||||
use llm_multiverse_proto::llm_multiverse::v1::{
|
|
||||||
GenerateEmbeddingRequest, GenerateEmbeddingResponse, InferenceRequest,
|
|
||||||
InferenceResponse, IsModelReadyRequest, IsModelReadyResponse,
|
|
||||||
StreamInferenceRequest, StreamInferenceResponse,
|
|
||||||
};
|
|
||||||
use std::net::SocketAddr;
|
|
||||||
use tonic::{Request, Response, Status};
|
|
||||||
|
|
||||||
pub struct FullMockGateway;
|
|
||||||
|
|
||||||
#[tonic::async_trait]
|
|
||||||
impl ModelGatewayService for FullMockGateway {
|
|
||||||
type StreamInferenceStream =
|
|
||||||
tokio_stream::Empty<Result<StreamInferenceResponse, Status>>;
|
|
||||||
|
|
||||||
async fn stream_inference(
|
|
||||||
&self,
|
|
||||||
_request: Request<StreamInferenceRequest>,
|
|
||||||
) -> Result<Response<Self::StreamInferenceStream>, Status> {
|
|
||||||
Err(Status::unimplemented("not needed"))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn inference(
|
|
||||||
&self,
|
|
||||||
_request: Request<InferenceRequest>,
|
|
||||||
) -> Result<Response<InferenceResponse>, Status> {
|
|
||||||
Ok(Response::new(InferenceResponse {
|
|
||||||
text: "SEGMENT: Extracted relevant text.\nCONFIDENCE: 0.85".to_string(),
|
|
||||||
finish_reason: "stop".to_string(),
|
|
||||||
tokens_used: 30,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn generate_embedding(
|
|
||||||
&self,
|
|
||||||
_request: Request<GenerateEmbeddingRequest>,
|
|
||||||
) -> Result<Response<GenerateEmbeddingResponse>, Status> {
|
|
||||||
let embedding = vec![0.10_f32; EMBEDDING_DIM];
|
|
||||||
Ok(Response::new(GenerateEmbeddingResponse {
|
|
||||||
embedding,
|
|
||||||
dimensions: EMBEDDING_DIM as u32,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn is_model_ready(
|
|
||||||
&self,
|
|
||||||
_request: Request<IsModelReadyRequest>,
|
|
||||||
) -> Result<Response<IsModelReadyResponse>, Status> {
|
|
||||||
Err(Status::unimplemented("not needed"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn start_full_mock_gateway() -> SocketAddr {
|
|
||||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
|
||||||
let addr = listener.local_addr().unwrap();
|
|
||||||
let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
|
|
||||||
tokio::spawn(async move {
|
|
||||||
tonic::transport::Server::builder()
|
|
||||||
.add_service(ModelGatewayServiceServer::new(FullMockGateway))
|
|
||||||
.serve_with_incoming(incoming)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
});
|
|
||||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
|
||||||
addr
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper to populate DB with test data for extraction tests.
|
/// Helper to populate DB with test data for extraction tests.
|
||||||
@@ -1101,7 +1072,7 @@ mod tests {
|
|||||||
populate_test_db(&db);
|
populate_test_db(&db);
|
||||||
let db_arc = Arc::new(db);
|
let db_arc = Arc::new(db);
|
||||||
|
|
||||||
let addr = extraction_test_helpers::start_full_mock_gateway().await;
|
let addr = start_full_mock_gateway().await;
|
||||||
let endpoint = format!("http://{addr}");
|
let endpoint = format!("http://{addr}");
|
||||||
|
|
||||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
||||||
@@ -1168,7 +1139,7 @@ mod tests {
|
|||||||
populate_test_db(&db);
|
populate_test_db(&db);
|
||||||
let db_arc = Arc::new(db);
|
let db_arc = Arc::new(db);
|
||||||
|
|
||||||
let addr = extraction_test_helpers::start_full_mock_gateway().await;
|
let addr = start_full_mock_gateway().await;
|
||||||
let endpoint = format!("http://{addr}");
|
let endpoint = format!("http://{addr}");
|
||||||
|
|
||||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
||||||
@@ -1230,7 +1201,7 @@ mod tests {
|
|||||||
populate_test_db(&db);
|
populate_test_db(&db);
|
||||||
let db_arc = Arc::new(db);
|
let db_arc = Arc::new(db);
|
||||||
|
|
||||||
let addr = extraction_test_helpers::start_full_mock_gateway().await;
|
let addr = start_full_mock_gateway().await;
|
||||||
let endpoint = format!("http://{addr}");
|
let endpoint = format!("http://{addr}");
|
||||||
|
|
||||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
||||||
@@ -1293,7 +1264,7 @@ mod tests {
|
|||||||
populate_test_db(&db);
|
populate_test_db(&db);
|
||||||
let db_arc = Arc::new(db);
|
let db_arc = Arc::new(db);
|
||||||
|
|
||||||
let addr = extraction_test_helpers::start_full_mock_gateway().await;
|
let addr = start_full_mock_gateway().await;
|
||||||
let endpoint = format!("http://{addr}");
|
let endpoint = format!("http://{addr}");
|
||||||
|
|
||||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
||||||
@@ -1344,7 +1315,7 @@ mod tests {
|
|||||||
populate_test_db(&db);
|
populate_test_db(&db);
|
||||||
let db_arc = Arc::new(db);
|
let db_arc = Arc::new(db);
|
||||||
|
|
||||||
let addr = extraction_test_helpers::start_full_mock_gateway().await;
|
let addr = start_full_mock_gateway().await;
|
||||||
let endpoint = format!("http://{addr}");
|
let endpoint = format!("http://{addr}");
|
||||||
|
|
||||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
||||||
@@ -1413,7 +1384,7 @@ mod tests {
|
|||||||
populate_test_db(&db);
|
populate_test_db(&db);
|
||||||
let db_arc = Arc::new(db);
|
let db_arc = Arc::new(db);
|
||||||
|
|
||||||
let addr = extraction_test_helpers::start_full_mock_gateway().await;
|
let addr = start_full_mock_gateway().await;
|
||||||
let endpoint = format!("http://{addr}");
|
let endpoint = format!("http://{addr}");
|
||||||
|
|
||||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
||||||
@@ -1474,7 +1445,7 @@ mod tests {
|
|||||||
populate_test_db(&db);
|
populate_test_db(&db);
|
||||||
let db_arc = Arc::new(db);
|
let db_arc = Arc::new(db);
|
||||||
|
|
||||||
let addr = extraction_test_helpers::start_full_mock_gateway().await;
|
let addr = start_full_mock_gateway().await;
|
||||||
let endpoint = format!("http://{addr}");
|
let endpoint = format!("http://{addr}");
|
||||||
|
|
||||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
||||||
|
|||||||
Reference in New Issue
Block a user