40
Cargo.lock
generated
40
Cargo.lock
generated
@@ -680,6 +680,21 @@ version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-executor",
|
||||
"futures-io",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.32"
|
||||
@@ -696,12 +711,34 @@ version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
|
||||
|
||||
[[package]]
|
||||
name = "futures-executor"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.32"
|
||||
@@ -720,8 +757,10 @@ version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-macro",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"memchr",
|
||||
@@ -1367,6 +1406,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"chrono",
|
||||
"duckdb",
|
||||
"futures",
|
||||
"llm-multiverse-proto",
|
||||
"prost",
|
||||
"prost-types",
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
| #114 | Tech debt: minor findings from issue #28 review | Phase 4 | `COMPLETED` | Rust | [issue-114.md](issue-114.md) |
|
||||
| #116 | Tech debt: minor findings from issue #29 review | Phase 4 | `COMPLETED` | Rust | [issue-116.md](issue-116.md) |
|
||||
| #118 | Tech debt: minor findings from issue #30 review | Phase 4 | `COMPLETED` | Rust | [issue-118.md](issue-118.md) |
|
||||
| #120 | Tech debt: minor findings from issue #31 review | Phase 4 | `COMPLETED` | Rust | [issue-120.md](issue-120.md) |
|
||||
|
||||
## Status Legend
|
||||
|
||||
|
||||
43
implementation-plans/issue-120.md
Normal file
43
implementation-plans/issue-120.md
Normal file
@@ -0,0 +1,43 @@
|
||||
# Issue #120: Tech debt: minor findings from issue #31 review
|
||||
|
||||
## Summary
|
||||
|
||||
Address three tech debt items from the extraction step review:
|
||||
|
||||
1. **Add `[UNTRUSTED CONTENT]` markers for external-provenance memories** - Wrap the corpus text in `[UNTRUSTED CONTENT]...[END UNTRUSTED CONTENT]` markers when the memory has external provenance (`provenance == 2`), to mitigate prompt injection.
|
||||
|
||||
2. **Parallel extraction for batch candidates** - Change `extract_batch` from sequential to parallel processing using `futures::future::join_all`, reducing latency for multiple candidates.
|
||||
|
||||
3. **Consolidate duplicate mock gateway test helpers** - Merge `FullMockGateway` (service.rs) and `TestGateway` (service.rs query_memory test) into a shared `test_helpers` module.
|
||||
|
||||
## Item 1: [UNTRUSTED CONTENT] markers
|
||||
|
||||
### Approach
|
||||
|
||||
Add a `provenance` parameter to `build_extraction_prompt`. When `provenance == 2` (external), wrap the corpus block in `[UNTRUSTED CONTENT]...[END UNTRUSTED CONTENT]` markers. Update `extract` and `extract_batch` to pass the candidate's provenance field through.
|
||||
|
||||
### Files changed
|
||||
|
||||
- `services/memory/src/extraction/prompt.rs` - Add provenance param to `build_extraction_prompt`
|
||||
- `services/memory/src/extraction/mod.rs` - Pass provenance through `extract` and `extract_batch`
|
||||
|
||||
## Item 2: Parallel extraction
|
||||
|
||||
### Approach
|
||||
|
||||
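Replace the sequential `for` loop in `extract_batch` with `futures::future::join_all` to process all candidates concurrently. The `ExtractionClient` wraps a `ModelGatewayServiceClient` which supports concurrent calls via `clone()`. `join_all` yields its outputs in the same order as the input futures, so the returned `Vec<ExtractionResult>` still lines up with `candidates`. Note that every request is issued at once; there is no concurrency limit.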
|
||||
|
||||
### Files changed
|
||||
|
||||
- `services/memory/Cargo.toml` - Add `futures` dependency
|
||||
- `services/memory/src/extraction/mod.rs` - Rewrite `extract_batch` with `join_all`
|
||||
|
||||
## Item 3: Consolidate mock gateway helpers
|
||||
|
||||
### Approach
|
||||
|
||||
Create a `test_helpers` module at the top of the `#[cfg(test)]` section in `service.rs` containing a single configurable `MockGateway` struct that handles both embedding and inference RPCs. Replace `FullMockGateway` and the inline `TestGateway` with this unified type.
|
||||
|
||||
### Files changed
|
||||
|
||||
- `services/memory/src/service.rs` - Consolidate into shared `test_helpers` module, update test functions
|
||||
@@ -21,6 +21,7 @@ tokio-stream = "0.1"
|
||||
duckdb = { version = "1", features = ["bundled"] }
|
||||
chrono = "0.4"
|
||||
regex = "1"
|
||||
futures = "0.3"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
|
||||
@@ -67,6 +67,9 @@ impl ExtractionClient {
|
||||
/// 4. Include a confidence score (0.0-1.0)
|
||||
///
|
||||
/// Uses `TaskComplexity::Simple` for lightweight model routing.
|
||||
///
|
||||
/// The `provenance` parameter controls whether the corpus is wrapped in
|
||||
/// `[UNTRUSTED CONTENT]` markers (when `provenance == 2`, i.e. external).
|
||||
pub async fn extract(
|
||||
&self,
|
||||
context: &SessionContext,
|
||||
@@ -74,6 +77,7 @@ impl ExtractionClient {
|
||||
corpus: &str,
|
||||
memory_name: &str,
|
||||
memory_id: &str,
|
||||
provenance: i32,
|
||||
) -> Result<ExtractionResult, ExtractionError> {
|
||||
// Skip extraction for empty corpus
|
||||
if corpus.is_empty() {
|
||||
@@ -84,7 +88,8 @@ impl ExtractionClient {
|
||||
});
|
||||
}
|
||||
|
||||
let extraction_prompt = prompt::build_extraction_prompt(query, corpus, memory_name);
|
||||
let extraction_prompt =
|
||||
prompt::build_extraction_prompt(query, corpus, memory_name, provenance);
|
||||
|
||||
let request = InferenceRequest {
|
||||
params: Some(InferenceParams {
|
||||
@@ -104,9 +109,10 @@ impl ExtractionClient {
|
||||
Ok(prompt::parse_extraction_response(&response.text, memory_id))
|
||||
}
|
||||
|
||||
/// Extract segments for multiple candidates.
|
||||
/// Extract segments for multiple candidates in parallel.
|
||||
///
|
||||
/// Processes candidates sequentially to avoid overloading the gateway.
|
||||
/// Uses `futures::future::join_all` to process all candidates concurrently,
|
||||
/// reducing total latency compared to sequential processing.
|
||||
/// On failure for a single candidate, that candidate's extraction is
|
||||
/// skipped (returns the full corpus as fallback) and processing continues.
|
||||
pub async fn extract_batch(
|
||||
@@ -115,36 +121,38 @@ impl ExtractionClient {
|
||||
query: &str,
|
||||
candidates: &[RetrievalCandidate],
|
||||
) -> Vec<ExtractionResult> {
|
||||
let mut results = Vec::with_capacity(candidates.len());
|
||||
|
||||
for candidate in candidates {
|
||||
match self
|
||||
.extract(
|
||||
context,
|
||||
query,
|
||||
&candidate.corpus,
|
||||
&candidate.name,
|
||||
&candidate.memory_id,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => results.push(result),
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
memory_id = %candidate.memory_id,
|
||||
error = %e,
|
||||
"Extraction failed for candidate, using full corpus as fallback"
|
||||
);
|
||||
results.push(ExtractionResult {
|
||||
memory_id: candidate.memory_id.clone(),
|
||||
segment: candidate.corpus.clone(),
|
||||
confidence: 0.0,
|
||||
});
|
||||
let futures: Vec<_> = candidates
|
||||
.iter()
|
||||
.map(|candidate| async {
|
||||
match self
|
||||
.extract(
|
||||
context,
|
||||
query,
|
||||
&candidate.corpus,
|
||||
&candidate.name,
|
||||
&candidate.memory_id,
|
||||
candidate.provenance,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
memory_id = %candidate.memory_id,
|
||||
error = %e,
|
||||
"Extraction failed for candidate, using full corpus as fallback"
|
||||
);
|
||||
ExtractionResult {
|
||||
memory_id: candidate.memory_id.clone(),
|
||||
segment: candidate.corpus.clone(),
|
||||
confidence: 0.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
results
|
||||
futures::future::join_all(futures).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -302,7 +310,7 @@ mod tests {
|
||||
};
|
||||
|
||||
let result = client
|
||||
.extract(&ctx, "test query", "full corpus text here", "mem name", "mem-1")
|
||||
.extract(&ctx, "test query", "full corpus text here", "mem name", "mem-1", 1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -329,7 +337,7 @@ mod tests {
|
||||
};
|
||||
|
||||
let result = client
|
||||
.extract(&ctx, "test query", "corpus", "name", "mem-1")
|
||||
.extract(&ctx, "test query", "corpus", "name", "mem-1", 1)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
@@ -359,7 +367,7 @@ mod tests {
|
||||
};
|
||||
|
||||
let result = client
|
||||
.extract(&ctx, "test query", "", "mem name", "mem-1")
|
||||
.extract(&ctx, "test query", "", "mem name", "mem-1", 1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -20,7 +20,22 @@ pub struct ExtractionResult {
|
||||
/// The prompt instructs the model to extract only the relevant segment
|
||||
/// from the corpus, given the query context. Output format is structured
|
||||
/// to enable reliable parsing.
|
||||
pub fn build_extraction_prompt(query: &str, corpus: &str, memory_name: &str) -> String {
|
||||
///
|
||||
/// When `provenance` is `2` (external), the corpus is wrapped in
|
||||
/// `[UNTRUSTED CONTENT]...[END UNTRUSTED CONTENT]` markers to mitigate
|
||||
/// prompt injection from external-provenance memories.
|
||||
pub fn build_extraction_prompt(
|
||||
query: &str,
|
||||
corpus: &str,
|
||||
memory_name: &str,
|
||||
provenance: i32,
|
||||
) -> String {
|
||||
let corpus_block = if provenance == 2 {
|
||||
format!("[UNTRUSTED CONTENT]\n{corpus}\n[END UNTRUSTED CONTENT]")
|
||||
} else {
|
||||
corpus.to_string()
|
||||
};
|
||||
|
||||
format!(
|
||||
"Given the following search query and memory content, extract ONLY the segment \
|
||||
of the content that is most relevant to the query. Return a concise, focused \
|
||||
@@ -31,7 +46,7 @@ Memory name: {memory_name}\n\
|
||||
\n\
|
||||
Content:\n\
|
||||
---\n\
|
||||
{corpus}\n\
|
||||
{corpus_block}\n\
|
||||
---\n\
|
||||
\n\
|
||||
Respond in this exact format:\n\
|
||||
@@ -99,22 +114,43 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_build_extraction_prompt_contains_query() {
|
||||
let prompt = build_extraction_prompt("how to sort arrays", "some corpus", "mem name");
|
||||
let prompt = build_extraction_prompt("how to sort arrays", "some corpus", "mem name", 1);
|
||||
assert!(prompt.contains("how to sort arrays"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_extraction_prompt_contains_corpus() {
|
||||
let prompt = build_extraction_prompt("query", "the corpus content here", "mem name");
|
||||
let prompt = build_extraction_prompt("query", "the corpus content here", "mem name", 1);
|
||||
assert!(prompt.contains("the corpus content here"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_extraction_prompt_contains_memory_name() {
|
||||
let prompt = build_extraction_prompt("query", "corpus", "Rust Sort Algorithm");
|
||||
let prompt = build_extraction_prompt("query", "corpus", "Rust Sort Algorithm", 1);
|
||||
assert!(prompt.contains("Rust Sort Algorithm"));
|
||||
}
|
||||
|
||||
#[test]
fn test_build_extraction_prompt_internal_no_untrusted_marker() {
    // Provenance `1` is not the external value (`2`), so neither the opening
    // nor the closing marker may appear in the rendered prompt.
    let rendered = build_extraction_prompt("query", "corpus text", "mem", 1);
    for marker in ["[UNTRUSTED CONTENT]", "[END UNTRUSTED CONTENT]"] {
        assert!(!rendered.contains(marker));
    }
}
|
||||
|
||||
#[test]
fn test_build_extraction_prompt_external_has_untrusted_marker() {
    // Provenance `2` (external) must produce both markers and still carry
    // the original corpus text.
    let rendered = build_extraction_prompt("query", "external corpus", "mem", 2);
    let expected = [
        "[UNTRUSTED CONTENT]",
        "[END UNTRUSTED CONTENT]",
        "external corpus",
    ];
    for needle in expected {
        assert!(rendered.contains(needle));
    }
}
|
||||
|
||||
#[test]
fn test_build_extraction_prompt_unspecified_no_untrusted_marker() {
    // Provenance `0` is not the external value (`2`), so the corpus must not
    // be wrapped.
    let has_marker = build_extraction_prompt("query", "corpus text", "mem", 0)
        .contains("[UNTRUSTED CONTENT]");
    assert!(!has_marker);
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_extraction_response_valid() {
|
||||
let response = "SEGMENT: The quicksort algorithm divides the array.\nCONFIDENCE: 0.8";
|
||||
|
||||
@@ -463,6 +463,113 @@ mod tests {
|
||||
use crate::config::{CacheConfig, ExtractionConfig, ProvenanceConfig, RetrievalConfig};
|
||||
use llm_multiverse_proto::llm_multiverse::v1::{MemoryEntry, SessionContext};
|
||||
|
||||
/// Shared mock Model Gateway for tests that need embedding and/or inference RPCs.
|
||||
///
|
||||
/// Consolidates `FullMockGateway` and the inline `TestGateway` that were
|
||||
/// previously duplicated across test functions.
|
||||
mod test_helpers {
|
||||
use crate::db::schema::EMBEDDING_DIM;
|
||||
use llm_multiverse_proto::llm_multiverse::v1::model_gateway_service_server::{
|
||||
ModelGatewayService, ModelGatewayServiceServer,
|
||||
};
|
||||
use llm_multiverse_proto::llm_multiverse::v1::{
|
||||
GenerateEmbeddingRequest, GenerateEmbeddingResponse, InferenceRequest,
|
||||
InferenceResponse, IsModelReadyRequest, IsModelReadyResponse,
|
||||
StreamInferenceRequest, StreamInferenceResponse,
|
||||
};
|
||||
use std::net::SocketAddr;
|
||||
use tonic::{Request, Response, Status};
|
||||
|
||||
/// A configurable mock gateway that handles embedding and inference RPCs.
|
||||
///
|
||||
/// - `inference_response`: If `Some`, inference calls return this text.
|
||||
/// If `None`, inference calls return `Unimplemented`.
|
||||
/// - Embedding calls always return a vector of `0.10` values matching
|
||||
/// `EMBEDDING_DIM`.
|
||||
pub struct MockGateway {
|
||||
pub inference_response: Option<String>,
|
||||
}
|
||||
|
||||
impl MockGateway {
|
||||
/// Create a gateway that supports embeddings but not inference.
|
||||
pub fn embedding_only() -> Self {
|
||||
Self {
|
||||
inference_response: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a gateway that supports both embeddings and inference.
|
||||
pub fn full(inference_text: &str) -> Self {
|
||||
Self {
|
||||
inference_response: Some(inference_text.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl ModelGatewayService for MockGateway {
|
||||
type StreamInferenceStream =
|
||||
tokio_stream::Empty<Result<StreamInferenceResponse, Status>>;
|
||||
|
||||
async fn stream_inference(
|
||||
&self,
|
||||
_request: Request<StreamInferenceRequest>,
|
||||
) -> Result<Response<Self::StreamInferenceStream>, Status> {
|
||||
Err(Status::unimplemented("not needed"))
|
||||
}
|
||||
|
||||
async fn inference(
|
||||
&self,
|
||||
_request: Request<InferenceRequest>,
|
||||
) -> Result<Response<InferenceResponse>, Status> {
|
||||
match &self.inference_response {
|
||||
Some(text) => Ok(Response::new(InferenceResponse {
|
||||
text: text.clone(),
|
||||
finish_reason: "stop".to_string(),
|
||||
tokens_used: 30,
|
||||
})),
|
||||
None => Err(Status::unimplemented("not needed")),
|
||||
}
|
||||
}
|
||||
|
||||
async fn generate_embedding(
|
||||
&self,
|
||||
_request: Request<GenerateEmbeddingRequest>,
|
||||
) -> Result<Response<GenerateEmbeddingResponse>, Status> {
|
||||
let embedding = vec![0.10_f32; EMBEDDING_DIM];
|
||||
Ok(Response::new(GenerateEmbeddingResponse {
|
||||
embedding,
|
||||
dimensions: EMBEDDING_DIM as u32,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn is_model_ready(
|
||||
&self,
|
||||
_request: Request<IsModelReadyRequest>,
|
||||
) -> Result<Response<IsModelReadyResponse>, Status> {
|
||||
Err(Status::unimplemented("not needed"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a mock Model Gateway server and return its address.
|
||||
pub async fn start_mock_gateway(gateway: MockGateway) -> SocketAddr {
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
|
||||
tokio::spawn(async move {
|
||||
tonic::transport::Server::builder()
|
||||
.add_service(ModelGatewayServiceServer::new(gateway))
|
||||
.serve_with_incoming(incoming)
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
addr
|
||||
}
|
||||
}
|
||||
|
||||
fn valid_ctx() -> SessionContext {
|
||||
SessionContext {
|
||||
session_id: "sess-1".into(),
|
||||
@@ -649,7 +756,6 @@ mod tests {
|
||||
.expect("DB setup failed");
|
||||
|
||||
// Create service with mock embedding client via a real mock gateway server.
|
||||
// (MockEmbeddingGenerator can't be used directly since the field expects EmbeddingClient.)
|
||||
let db_arc = Arc::new(db);
|
||||
let mut svc = MemoryServiceImpl::new(
|
||||
db_arc,
|
||||
@@ -658,75 +764,9 @@ mod tests {
|
||||
ExtractionConfig::default(),
|
||||
CacheConfig::default(),
|
||||
);
|
||||
// Set embedding client to mock by storing it as Arc<Mutex<dyn EmbeddingGenerator>>
|
||||
// Since the field type is Option<Arc<Mutex<EmbeddingClient>>>, we need a different approach.
|
||||
// Instead, set up a real mock gateway server.
|
||||
|
||||
// Use the mock gateway server approach from embedding tests
|
||||
use llm_multiverse_proto::llm_multiverse::v1::model_gateway_service_server::{
|
||||
ModelGatewayService, ModelGatewayServiceServer,
|
||||
};
|
||||
use llm_multiverse_proto::llm_multiverse::v1::{
|
||||
GenerateEmbeddingRequest, GenerateEmbeddingResponse, InferenceRequest,
|
||||
InferenceResponse, IsModelReadyRequest, IsModelReadyResponse,
|
||||
StreamInferenceRequest, StreamInferenceResponse,
|
||||
};
|
||||
|
||||
struct TestGateway;
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl ModelGatewayService for TestGateway {
|
||||
type StreamInferenceStream =
|
||||
tokio_stream::Empty<Result<StreamInferenceResponse, Status>>;
|
||||
|
||||
async fn stream_inference(
|
||||
&self,
|
||||
_request: Request<StreamInferenceRequest>,
|
||||
) -> Result<Response<Self::StreamInferenceStream>, Status> {
|
||||
Err(Status::unimplemented("not needed"))
|
||||
}
|
||||
|
||||
async fn inference(
|
||||
&self,
|
||||
_request: Request<InferenceRequest>,
|
||||
) -> Result<Response<InferenceResponse>, Status> {
|
||||
Err(Status::unimplemented("not needed"))
|
||||
}
|
||||
|
||||
async fn generate_embedding(
|
||||
&self,
|
||||
_request: Request<GenerateEmbeddingRequest>,
|
||||
) -> Result<Response<GenerateEmbeddingResponse>, Status> {
|
||||
// Return a vector with value 0.10 matching test data
|
||||
let embedding = vec![0.10_f32; EMBEDDING_DIM];
|
||||
Ok(Response::new(GenerateEmbeddingResponse {
|
||||
embedding,
|
||||
dimensions: EMBEDDING_DIM as u32,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn is_model_ready(
|
||||
&self,
|
||||
_request: Request<IsModelReadyRequest>,
|
||||
) -> Result<Response<IsModelReadyResponse>, Status> {
|
||||
Err(Status::unimplemented("not needed"))
|
||||
}
|
||||
}
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("bind failed");
|
||||
let addr = listener.local_addr().expect("no local addr");
|
||||
let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
|
||||
tokio::spawn(async move {
|
||||
tonic::transport::Server::builder()
|
||||
.add_service(ModelGatewayServiceServer::new(TestGateway))
|
||||
.serve_with_incoming(incoming)
|
||||
.await
|
||||
.expect("server failed");
|
||||
});
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
|
||||
let addr =
|
||||
test_helpers::start_mock_gateway(test_helpers::MockGateway::embedding_only()).await;
|
||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&format!("http://{addr}"))
|
||||
.await
|
||||
.expect("connect failed");
|
||||
@@ -982,81 +1022,12 @@ mod tests {
|
||||
assert_eq!(prov.trust_level, crate::provenance::TrustLevel::Revoked);
|
||||
}
|
||||
|
||||
/// Helper to set up a mock gateway server that handles both embedding and inference.
|
||||
/// Returns the gateway address.
|
||||
mod extraction_test_helpers {
|
||||
use crate::db::schema::EMBEDDING_DIM;
|
||||
use llm_multiverse_proto::llm_multiverse::v1::model_gateway_service_server::{
|
||||
ModelGatewayService, ModelGatewayServiceServer,
|
||||
};
|
||||
use llm_multiverse_proto::llm_multiverse::v1::{
|
||||
GenerateEmbeddingRequest, GenerateEmbeddingResponse, InferenceRequest,
|
||||
InferenceResponse, IsModelReadyRequest, IsModelReadyResponse,
|
||||
StreamInferenceRequest, StreamInferenceResponse,
|
||||
};
|
||||
use std::net::SocketAddr;
|
||||
use tonic::{Request, Response, Status};
|
||||
|
||||
pub struct FullMockGateway;
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl ModelGatewayService for FullMockGateway {
|
||||
type StreamInferenceStream =
|
||||
tokio_stream::Empty<Result<StreamInferenceResponse, Status>>;
|
||||
|
||||
async fn stream_inference(
|
||||
&self,
|
||||
_request: Request<StreamInferenceRequest>,
|
||||
) -> Result<Response<Self::StreamInferenceStream>, Status> {
|
||||
Err(Status::unimplemented("not needed"))
|
||||
}
|
||||
|
||||
async fn inference(
|
||||
&self,
|
||||
_request: Request<InferenceRequest>,
|
||||
) -> Result<Response<InferenceResponse>, Status> {
|
||||
Ok(Response::new(InferenceResponse {
|
||||
text: "SEGMENT: Extracted relevant text.\nCONFIDENCE: 0.85".to_string(),
|
||||
finish_reason: "stop".to_string(),
|
||||
tokens_used: 30,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn generate_embedding(
|
||||
&self,
|
||||
_request: Request<GenerateEmbeddingRequest>,
|
||||
) -> Result<Response<GenerateEmbeddingResponse>, Status> {
|
||||
let embedding = vec![0.10_f32; EMBEDDING_DIM];
|
||||
Ok(Response::new(GenerateEmbeddingResponse {
|
||||
embedding,
|
||||
dimensions: EMBEDDING_DIM as u32,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn is_model_ready(
|
||||
&self,
|
||||
_request: Request<IsModelReadyRequest>,
|
||||
) -> Result<Response<IsModelReadyResponse>, Status> {
|
||||
Err(Status::unimplemented("not needed"))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_full_mock_gateway() -> SocketAddr {
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
|
||||
tokio::spawn(async move {
|
||||
tonic::transport::Server::builder()
|
||||
.add_service(ModelGatewayServiceServer::new(FullMockGateway))
|
||||
.serve_with_incoming(incoming)
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
addr
|
||||
}
|
||||
/// Start a full mock gateway (embedding + inference) for extraction tests.
|
||||
async fn start_full_mock_gateway() -> std::net::SocketAddr {
|
||||
test_helpers::start_mock_gateway(test_helpers::MockGateway::full(
|
||||
"SEGMENT: Extracted relevant text.\nCONFIDENCE: 0.85",
|
||||
))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Helper to populate DB with test data for extraction tests.
|
||||
@@ -1101,7 +1072,7 @@ mod tests {
|
||||
populate_test_db(&db);
|
||||
let db_arc = Arc::new(db);
|
||||
|
||||
let addr = extraction_test_helpers::start_full_mock_gateway().await;
|
||||
let addr = start_full_mock_gateway().await;
|
||||
let endpoint = format!("http://{addr}");
|
||||
|
||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
||||
@@ -1168,7 +1139,7 @@ mod tests {
|
||||
populate_test_db(&db);
|
||||
let db_arc = Arc::new(db);
|
||||
|
||||
let addr = extraction_test_helpers::start_full_mock_gateway().await;
|
||||
let addr = start_full_mock_gateway().await;
|
||||
let endpoint = format!("http://{addr}");
|
||||
|
||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
||||
@@ -1230,7 +1201,7 @@ mod tests {
|
||||
populate_test_db(&db);
|
||||
let db_arc = Arc::new(db);
|
||||
|
||||
let addr = extraction_test_helpers::start_full_mock_gateway().await;
|
||||
let addr = start_full_mock_gateway().await;
|
||||
let endpoint = format!("http://{addr}");
|
||||
|
||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
||||
@@ -1293,7 +1264,7 @@ mod tests {
|
||||
populate_test_db(&db);
|
||||
let db_arc = Arc::new(db);
|
||||
|
||||
let addr = extraction_test_helpers::start_full_mock_gateway().await;
|
||||
let addr = start_full_mock_gateway().await;
|
||||
let endpoint = format!("http://{addr}");
|
||||
|
||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
||||
@@ -1344,7 +1315,7 @@ mod tests {
|
||||
populate_test_db(&db);
|
||||
let db_arc = Arc::new(db);
|
||||
|
||||
let addr = extraction_test_helpers::start_full_mock_gateway().await;
|
||||
let addr = start_full_mock_gateway().await;
|
||||
let endpoint = format!("http://{addr}");
|
||||
|
||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
||||
@@ -1413,7 +1384,7 @@ mod tests {
|
||||
populate_test_db(&db);
|
||||
let db_arc = Arc::new(db);
|
||||
|
||||
let addr = extraction_test_helpers::start_full_mock_gateway().await;
|
||||
let addr = start_full_mock_gateway().await;
|
||||
let endpoint = format!("http://{addr}");
|
||||
|
||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
||||
@@ -1474,7 +1445,7 @@ mod tests {
|
||||
populate_test_db(&db);
|
||||
let db_arc = Arc::new(db);
|
||||
|
||||
let addr = extraction_test_helpers::start_full_mock_gateway().await;
|
||||
let addr = start_full_mock_gateway().await;
|
||||
let endpoint = format!("http://{addr}");
|
||||
|
||||
let embedding_client = crate::embedding::EmbeddingClient::connect(&endpoint)
|
||||
|
||||
Reference in New Issue
Block a user