Compare commits

...

2 Commits

Author SHA1 Message Date
4759accc63 Merge pull request 'fix: resolve tech debt from issue #30 review (#118)' (#129) from feature/issue-118-tech-debt-30-review into main 2026-03-10 07:46:56 +01:00
Pi Agent
1179a77467 fix: resolve tech debt from issue #30 review
- Add defense-in-depth validation for memory IDs before SQL interpolation
  in stage2, stage3, and stage4 IN clauses (validate alphanumeric/hyphen/underscore only)
- Make scoring weights (name, description, corpus) configurable via RetrievalConfig
  instead of compile-time constants, with defaults 0.3/0.3/0.4

Closes #118

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-10 07:46:19 +01:00
8 changed files with 311 additions and 27 deletions

View File

@@ -33,6 +33,7 @@
| #33 | Implement provenance tagging and poisoning protection | Phase 4 | `COMPLETED` | Rust | [issue-033.md](issue-033.md) |
| #114 | Tech debt: minor findings from issue #28 review | Phase 4 | `COMPLETED` | Rust | [issue-114.md](issue-114.md) |
| #116 | Tech debt: minor findings from issue #29 review | Phase 4 | `COMPLETED` | Rust | [issue-116.md](issue-116.md) |
| #118 | Tech debt: minor findings from issue #30 review | Phase 4 | `COMPLETED` | Rust | [issue-118.md](issue-118.md) |
## Status Legend

View File

@@ -0,0 +1,41 @@
# Issue #118: Tech debt: minor findings from issue #30 review
## Summary
Address two tech debt items from the staged retrieval pipeline review:
1. **Validate memory IDs before SQL interpolation** - Add defense-in-depth validation that memory IDs contain only safe characters (alphanumeric, hyphens, underscores) before string-interpolating them into SQL IN clauses in stage2, stage3, and stage4.
2. **Make scoring weights configurable** - Move the compile-time constants `NAME_WEIGHT`, `DESCRIPTION_WEIGHT`, and `CORPUS_WEIGHT` from stage3.rs into `RetrievalConfig`, making them tunable without recompilation.
## Item 1: Validate memory IDs before SQL interpolation
### Approach
Since the IDs originate from database query results (not user input) and are UUID-like strings, full parameterization via temp tables adds complexity without proportionate security benefit. Instead, add a validation function that asserts all IDs contain only safe characters (`[a-zA-Z0-9_-]`) before interpolation, returning an error if any ID fails validation.
### Files changed
- `services/memory/src/retrieval/mod.rs` - Add `validate_memory_ids()` helper function
- `services/memory/src/retrieval/stage2.rs` - Call validation before building IN clause
- `services/memory/src/retrieval/stage3.rs` - Call validation before building IN clause
- `services/memory/src/retrieval/stage4.rs` - Call validation before building IN clause
## Item 2: Make scoring weights configurable
### Approach
Add three new fields to `RetrievalConfig`: `name_weight`, `description_weight`, `corpus_weight` with serde defaults matching the current constants (0.3, 0.3, 0.4). Thread these through `RetrievalParams` and into `stage3::load_and_rerank()`.
### Files changed
- `services/memory/src/config.rs` - Add weight fields to `RetrievalConfig`
- `services/memory/src/retrieval/mod.rs` - Add weight fields to `RetrievalParams`, update `from_config()`
- `services/memory/src/retrieval/stage3.rs` - Accept weights as parameters instead of using constants
- `services/memory/src/retrieval/pipeline.rs` - Pass weights from params to stage3
## Testing
- Add unit tests for `validate_memory_ids()` with valid and invalid inputs
- Add unit tests for configurable weights (custom values, TOML deserialization)
- Update existing stage3 tests to pass weight parameters

View File

@@ -15,6 +15,18 @@ pub struct RetrievalConfig {
/// Stage 4: Minimum cosine similarity score to include in final results (default: 0.3).
#[serde(default = "default_relevance_threshold")]
pub relevance_threshold: f32,
/// Weight for name embedding score in Stage 3 final score calculation (default: 0.3).
#[serde(default = "default_name_weight")]
pub name_weight: f32,
/// Weight for description embedding score in Stage 3 final score calculation (default: 0.3).
#[serde(default = "default_description_weight")]
pub description_weight: f32,
/// Weight for corpus embedding score in Stage 3 final score calculation (default: 0.4).
#[serde(default = "default_corpus_weight")]
pub corpus_weight: f32,
}
fn default_stage1_top_k() -> u32 {
@@ -29,6 +41,18 @@ fn default_relevance_threshold() -> f32 {
0.3
}
/// Serde default for `RetrievalConfig::name_weight` (Stage 3 name-embedding weight).
fn default_name_weight() -> f32 {
0.3
}
/// Serde default for `RetrievalConfig::description_weight` (Stage 3 description-embedding weight).
fn default_description_weight() -> f32 {
0.3
}
/// Serde default for `RetrievalConfig::corpus_weight` (Stage 3 corpus-embedding weight).
/// The three weights default to 0.3/0.3/0.4, summing to 1.0.
fn default_corpus_weight() -> f32 {
0.4
}
/// Configuration for the post-retrieval extraction step.
#[derive(Debug, Clone, Deserialize)]
pub struct ExtractionConfig {
@@ -124,6 +148,9 @@ impl Default for RetrievalConfig {
stage1_top_k: default_stage1_top_k(),
stage2_top_k: default_stage2_top_k(),
relevance_threshold: default_relevance_threshold(),
name_weight: default_name_weight(),
description_weight: default_description_weight(),
corpus_weight: default_corpus_weight(),
}
}
}
@@ -408,6 +435,39 @@ temperature = 0.3
assert!((config.extraction.temperature - 0.1).abs() < f32::EPSILON);
}
#[test]
fn test_retrieval_config_weights_defaults() {
    // The default config must carry the documented 0.3/0.3/0.4 weight split.
    let config = RetrievalConfig::default();
    let expected = [
        (config.name_weight, 0.3_f32),
        (config.description_weight, 0.3),
        (config.corpus_weight, 0.4),
    ];
    for (actual, want) in expected {
        assert!((actual - want).abs() < f32::EPSILON);
    }
}
#[test]
fn test_retrieval_config_weights_from_toml() {
    // Custom weights in the [retrieval] section must override the serde defaults.
    let dir = tempfile::tempdir().unwrap();
    let config_path = dir.path().join("memory.toml");
    let toml = r#"
host = "0.0.0.0"
port = 9999
db_path = "/var/lib/memory.duckdb"
[retrieval]
name_weight = 0.5
description_weight = 0.2
corpus_weight = 0.3
"#;
    std::fs::write(&config_path, toml).unwrap();
    let config = Config::load(Some(config_path.to_str().unwrap())).unwrap();
    let retrieval = &config.retrieval;
    assert!((retrieval.name_weight - 0.5).abs() < f32::EPSILON);
    assert!((retrieval.description_weight - 0.2).abs() < f32::EPSILON);
    assert!((retrieval.corpus_weight - 0.3).abs() < f32::EPSILON);
}
#[test]
fn test_retrieval_config_uses_defaults_when_omitted() {
let dir = tempfile::tempdir().unwrap();

View File

@@ -69,6 +69,12 @@ pub struct RetrievalParams {
pub min_trust_level: Option<i32>,
/// Whether to exclude revoked memories from results.
pub exclude_revoked: bool,
/// Weight for name embedding score in Stage 3 final score calculation.
pub name_weight: f32,
/// Weight for description embedding score in Stage 3 final score calculation.
pub description_weight: f32,
/// Weight for corpus embedding score in Stage 3 final score calculation.
pub corpus_weight: f32,
}
impl RetrievalParams {
@@ -86,6 +92,9 @@ impl RetrievalParams {
result_limit: if result_limit == 0 { 5 } else { result_limit },
min_trust_level: None,
exclude_revoked: true,
name_weight: config.name_weight,
description_weight: config.description_weight,
corpus_weight: config.corpus_weight,
}
}
}
@@ -117,6 +126,38 @@ pub(crate) fn build_score_map(pairs: &[(String, f32)]) -> HashMap<String, f32> {
pairs.iter().cloned().collect()
}
/// Validate that all memory IDs contain only safe characters for SQL interpolation.
///
/// Memory IDs are expected to be UUID-like strings containing only alphanumeric
/// characters, hyphens, and underscores. This validation provides defense-in-depth
/// against SQL injection even though IDs originate from database query results.
pub(crate) fn validate_memory_ids(ids: &[String]) -> Result<(), crate::db::DbError> {
    // A character is safe iff it is in [A-Za-z0-9_-]; anything else (quotes,
    // semicolons, whitespace, ...) would break out of a quoted SQL literal.
    let is_safe = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '_');
    match ids.iter().find(|id| !id.chars().all(is_safe)) {
        Some(id) => Err(crate::db::DbError::InvalidData(format!(
            "Memory ID contains unsafe characters: {id}"
        ))),
        None => Ok(()),
    }
}
/// Build a SQL IN clause from validated memory IDs.
///
/// Validates that all IDs contain only safe characters, then formats them
/// as a comma-separated list of quoted strings for use in SQL IN clauses.
pub(crate) fn build_id_list(ids: &[String]) -> Result<String, crate::db::DbError> {
    // Refuse to build the clause if any ID could escape its quoted literal.
    validate_memory_ids(ids)?;
    let mut quoted = Vec::with_capacity(ids.len());
    for id in ids {
        quoted.push(format!("'{id}'"));
    }
    Ok(quoted.join(", "))
}
#[cfg(test)]
mod tests {
use super::*;
@@ -157,4 +198,65 @@ mod tests {
let err = RetrievalError::NoEmbeddingClient;
assert_eq!(err.to_string(), "no embedding client configured");
}
#[test]
fn test_retrieval_params_from_config_includes_weights() {
    // Weights set on the config must be copied into the params verbatim.
    let mut config = RetrievalConfig::default();
    config.name_weight = 0.5;
    config.description_weight = 0.2;
    config.corpus_weight = 0.3;
    let params = RetrievalParams::from_config(&config, None, 5);
    assert!((params.name_weight - 0.5).abs() < f32::EPSILON);
    assert!((params.description_weight - 0.2).abs() < f32::EPSILON);
    assert!((params.corpus_weight - 0.3).abs() < f32::EPSILON);
}
#[test]
fn test_validate_memory_ids_valid() {
    // Alphanumerics, hyphens, underscores, and full UUIDs are all accepted.
    let ids: Vec<String> = ["mem-1", "abc123", "a-b-c_d", "550e8400-e29b-41d4-a716-446655440000"]
        .iter()
        .map(|s| s.to_string())
        .collect();
    assert!(validate_memory_ids(&ids).is_ok());
}
#[test]
fn test_validate_memory_ids_empty() {
    // An empty slice trivially validates.
    assert!(validate_memory_ids(&[]).is_ok());
}
#[test]
fn test_validate_memory_ids_rejects_quotes() {
    // A classic quote-based injection payload must be rejected.
    assert!(validate_memory_ids(&["mem-1' OR '1'='1".to_string()]).is_err());
}
#[test]
fn test_validate_memory_ids_rejects_semicolons() {
    // Statement separators must be rejected.
    assert!(validate_memory_ids(&["mem-1; DROP TABLE memories".to_string()]).is_err());
}
#[test]
fn test_validate_memory_ids_rejects_spaces() {
    // Whitespace sits outside the safe character set.
    assert!(validate_memory_ids(&["mem 1".to_string()]).is_err());
}
#[test]
fn test_build_id_list_formats_correctly() {
    // Valid IDs are single-quoted and joined with ", ".
    let ids: Vec<String> = vec!["mem-1".into(), "mem-2".into()];
    assert_eq!(build_id_list(&ids).unwrap(), "'mem-1', 'mem-2'");
}
#[test]
fn test_build_id_list_rejects_unsafe_ids() {
    // A single bad ID fails the whole list.
    let ids: Vec<String> = vec!["mem-1".into(), "bad'id".into()];
    assert!(build_id_list(&ids).is_err());
}
}

View File

@@ -84,6 +84,11 @@ pub fn execute_pipeline(
let stage2_ids: Vec<String> = stage2_results.iter().map(|(id, _)| id.clone()).collect();
tracing::debug!("Stage 3: Corpus load and re-rank");
let scoring_weights = stage3::ScoringWeights {
name: params.name_weight,
description: params.description_weight,
corpus: params.corpus_weight,
};
let stage3_results = db.with_connection(|conn| {
stage3::load_and_rerank(
conn,
@@ -91,6 +96,7 @@ pub fn execute_pipeline(
&stage2_ids,
&name_scores,
&description_scores,
&scoring_weights,
)
})?;
@@ -185,6 +191,9 @@ mod tests {
result_limit: 10,
min_trust_level: None,
exclude_revoked: false,
name_weight: 0.3,
description_weight: 0.3,
corpus_weight: 0.4,
};
let results = execute_pipeline(&db, &query_vector, &params)
@@ -225,6 +234,9 @@ mod tests {
result_limit: 5,
min_trust_level: None,
exclude_revoked: false,
name_weight: 0.3,
description_weight: 0.3,
corpus_weight: 0.4,
};
let results = execute_pipeline(&db, &query_vector, &params)
@@ -249,6 +261,9 @@ mod tests {
result_limit: 10,
min_trust_level: None,
exclude_revoked: false,
name_weight: 0.3,
description_weight: 0.3,
corpus_weight: 0.4,
};
let results = execute_pipeline(&db, &query_vector, &params)
@@ -354,6 +369,9 @@ mod tests {
result_limit: 10,
min_trust_level: None,
exclude_revoked: false,
name_weight: 0.3,
description_weight: 0.3,
corpus_weight: 0.4,
};
// Check initial access_count

View File

@@ -6,6 +6,7 @@
use crate::db::DbError;
use crate::embedding::store::format_vector_literal;
use crate::retrieval::build_id_list;
use duckdb::Connection;
/// Execute Stage 2: Re-rank Stage 1 candidates by description embedding similarity.
@@ -65,12 +66,8 @@ pub fn rerank_by_description_with_provenance(
let vector_literal = format_vector_literal(query_vector)?;
// Build IN clause with quoted IDs
let id_list: String = candidate_ids
.iter()
.map(|id| format!("'{id}'"))
.collect::<Vec<_>>()
.join(", ");
// Build IN clause with validated, quoted IDs
let id_list = build_id_list(candidate_ids)?;
let tag_clause = if tag_filter.is_some() {
" AND e.memory_id IN (SELECT memory_id FROM memory_tags WHERE tag = ?)"

View File

@@ -9,17 +9,29 @@ use std::collections::HashMap;
use crate::db::DbError;
use crate::embedding::store::format_vector_literal;
use crate::provenance::store as provenance_store;
use crate::retrieval::RetrievalCandidate;
use crate::retrieval::{build_id_list, RetrievalCandidate};
use duckdb::Connection;
/// Weight for name embedding score in final score calculation.
const NAME_WEIGHT: f32 = 0.3;
/// Scoring weights for computing the final weighted score in Stage 3.
///
/// The final score is `name * name_score + description * description_score
/// + corpus * corpus_score`. Callers typically populate these from
/// `RetrievalConfig` (defaults 0.3/0.3/0.4); nothing here enforces that the
/// weights sum to 1.0 — NOTE(review): confirm whether normalization is wanted.
#[derive(Debug, Clone, Copy)]
pub struct ScoringWeights {
/// Weight for name embedding score.
pub name: f32,
/// Weight for description embedding score.
pub description: f32,
/// Weight for corpus embedding score.
pub corpus: f32,
}
/// Weight for description embedding score in final score calculation.
const DESCRIPTION_WEIGHT: f32 = 0.3;
/// Weight for corpus embedding score in final score calculation.
const CORPUS_WEIGHT: f32 = 0.4;
impl Default for ScoringWeights {
fn default() -> Self {
Self {
name: 0.3,
description: 0.3,
corpus: 0.4,
}
}
}
/// Execute Stage 3: Load full memory entries and re-rank by corpus embedding.
///
@@ -27,7 +39,7 @@ const CORPUS_WEIGHT: f32 = 0.4;
/// between the query vector and each candidate's corpus embedding. Combines name,
/// description, and corpus scores into a final score using a weighted average:
///
/// `final_score = 0.3 * name_score + 0.3 * description_score + 0.4 * corpus_score`
/// `final_score = name_weight * name_score + description_weight * description_score + corpus_weight * corpus_score`
///
/// Returns candidates as fully populated `RetrievalCandidate` structs sorted by
/// `final_score` descending.
@@ -39,22 +51,21 @@ const CORPUS_WEIGHT: f32 = 0.4;
/// * `candidate_ids` - Memory IDs from Stage 2
/// * `name_scores` - Name embedding scores from Stage 1
/// * `description_scores` - Description embedding scores from Stage 2
/// * `weights` - Scoring weights for name, description, and corpus scores
pub fn load_and_rerank(
conn: &Connection,
query_vector: &[f32],
candidate_ids: &[String],
name_scores: &HashMap<String, f32>,
description_scores: &HashMap<String, f32>,
weights: &ScoringWeights,
) -> Result<Vec<RetrievalCandidate>, DbError> {
if candidate_ids.is_empty() {
return Ok(Vec::new());
}
let id_list: String = candidate_ids
.iter()
.map(|id| format!("'{id}'"))
.collect::<Vec<_>>()
.join(", ");
// Build IN clause with validated, quoted IDs
let id_list = build_id_list(candidate_ids)?;
// Load full memory entries
let memories = load_memory_entries(conn, &id_list)?;
@@ -83,7 +94,7 @@ pub fn load_and_rerank(
let ns = name_scores.get(&mem.id).copied().unwrap_or(0.0);
let ds = description_scores.get(&mem.id).copied().unwrap_or(0.0);
let cs = corpus_scores.get(&mem.id).copied().unwrap_or(0.0);
let final_score = NAME_WEIGHT * ns + DESCRIPTION_WEIGHT * ds + CORPUS_WEIGHT * cs;
let final_score = weights.name * ns + weights.description * ds + weights.corpus * cs;
RetrievalCandidate {
memory_id: mem.id.clone(),
@@ -330,6 +341,7 @@ mod tests {
&["mem-1".to_string()],
&name_scores,
&desc_scores,
&ScoringWeights::default(),
)?;
assert_eq!(candidates.len(), 1);
@@ -362,6 +374,7 @@ mod tests {
&["mem-1".to_string()],
&name_scores,
&desc_scores,
&ScoringWeights::default(),
)?;
assert_eq!(candidates.len(), 1);
@@ -405,6 +418,7 @@ mod tests {
&["mem-1".to_string(), "mem-2".to_string()],
&name_scores,
&desc_scores,
&ScoringWeights::default(),
)?;
assert_eq!(candidates.len(), 2);
@@ -418,4 +432,59 @@ mod tests {
})
.expect("test failed");
}
#[test]
// Verifies that `load_and_rerank` applies caller-supplied ScoringWeights rather
// than fixed constants, and that custom and default weights yield the documented
// weighted averages.
fn test_stage3_custom_weights() {
let db = DuckDbManager::in_memory().expect("DB creation failed");
db.with_connection(|conn| {
// One memory with 0.5-valued name/description/corpus embeddings and no tags.
insert_full_memory(conn, "mem-1", "M1", "D1", "C1", 0.5, 0.5, 0.5, &[], &[]);
ensure_hnsw_index(conn)?;
// Pre-set Stage 1/2 scores fed into the Stage 3 weighted combination.
let name_scores: HashMap<String, f32> = [("mem-1".to_string(), 0.8)].into();
let desc_scores: HashMap<String, f32> = [("mem-1".to_string(), 0.6)].into();
let query_vector = vec![0.5; EMBEDDING_DIM];
// Use custom weights that emphasize corpus score
let weights = ScoringWeights {
name: 0.1,
description: 0.1,
corpus: 0.8,
};
let candidates = load_and_rerank(
conn,
&query_vector,
&["mem-1".to_string()],
&name_scores,
&desc_scores,
&weights,
)?;
assert_eq!(candidates.len(), 1);
let c = &candidates[0];
// Expected score recomputed from the candidate's own corpus_score, since the
// corpus similarity is produced by the DB and not fixed by the test.
let expected = 0.1 * 0.8 + 0.1 * 0.6 + 0.8 * c.corpus_score;
assert!(
(c.final_score - expected).abs() < 0.001,
"final_score {} should equal custom weighted average {}",
c.final_score,
expected
);
// Verify it differs from default weights
let default_candidates = load_and_rerank(
conn,
&query_vector,
&["mem-1".to_string()],
&name_scores,
&desc_scores,
&ScoringWeights::default(),
)?;
let default_expected = 0.3 * 0.8 + 0.3 * 0.6 + 0.4 * default_candidates[0].corpus_score;
assert!(
(default_candidates[0].final_score - default_expected).abs() < 0.001,
"default weighted score should use 0.3/0.3/0.4 weights"
);
Ok(())
})
.expect("test failed");
}
}

View File

@@ -4,7 +4,7 @@
//! the result limit, and updates access tracking metadata.
use crate::db::DbError;
use crate::retrieval::RetrievalCandidate;
use crate::retrieval::{build_id_list, RetrievalCandidate};
use duckdb::Connection;
/// Execute Stage 4: Apply relevance threshold and limit results.
@@ -39,11 +39,7 @@ pub fn update_access_tracking(conn: &Connection, memory_ids: &[String]) -> Resul
return Ok(());
}
let id_list: String = memory_ids
.iter()
.map(|id| format!("'{id}'"))
.collect::<Vec<_>>()
.join(", ");
let id_list = build_id_list(memory_ids)?;
let sql = format!(
"UPDATE memories \