Refactor and add semantic cache

This commit is contained in:
shahondin1624
2026-02-15 17:09:55 +01:00
parent a0e36a7fe2
commit dcbb41d747
9 changed files with 145 additions and 68 deletions

View File

@@ -1,5 +1,6 @@
package de.shahondin1624.rag; package de.shahondin1624.rag;
import de.shahondin1624.rag.cache.SemanticCache;
import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStore;
@@ -7,7 +8,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.nio.file.Path; import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -24,6 +24,8 @@ public class LocalRagService {
private final RagIndexer indexer; private final RagIndexer indexer;
private final RagDocumentSplitter splitter; private final RagDocumentSplitter splitter;
private final SemanticCache semanticCache;
public static synchronized LocalRagService getInstance() { public static synchronized LocalRagService getInstance() {
if (instance == null) { if (instance == null) {
instance = new LocalRagService(); instance = new LocalRagService();
@@ -35,6 +37,7 @@ public class LocalRagService {
this.storeManager = new RagStoreManager(); this.storeManager = new RagStoreManager();
this.splitter = new RagDocumentSplitter(); this.splitter = new RagDocumentSplitter();
this.indexer = new RagIndexer(storeManager, splitter); this.indexer = new RagIndexer(storeManager, splitter);
this.semanticCache = new SemanticCache(50, 0.92);
initialize(); initialize();
} }
@@ -45,53 +48,64 @@ public class LocalRagService {
log.info("LocalRagService initialized successfully."); log.info("LocalRagService initialized successfully.");
} }
public void indexDirectory(Path path) { public void indexDirectory(final Path path) {
indexDirectory(path, null); indexDirectory(path, null);
} }
public void indexDirectory(Path path, Map<String, String> metadata) { public void indexDirectory(final Path path, final Map<String, String> metadata) {
indexer.indexDirectory(path, metadata); indexer.indexDirectory(path, metadata);
new RagFileWatcher(path, p -> indexer.indexFile(p, metadata), indexer::removeFile, indexer::isSupportedFile).start(); new RagFileWatcher(path, p -> indexer.indexFile(p, metadata), indexer::removeFile, indexer::isSupportedFile).start();
} }
public void indexFile(Path path) { public void indexFile(final Path path) {
indexFile(path, null); indexFile(path, null);
} }
public void indexFile(Path path, Map<String, String> metadata) { public void indexFile(final Path path, final Map<String, String> metadata) {
indexer.indexFile(path, metadata); indexer.indexFile(path, metadata);
} }
public void removeFile(Path path) { public void removeFile(final Path path) {
indexer.removeFile(path); indexer.removeFile(path);
} }
public void indexDocuments(List<String> docs) { public void indexDocuments(final List<String> docs) {
indexDocuments(docs, null); indexDocuments(docs, null);
} }
public void indexDocuments(List<String> docs, Map<String, String> metadata) { public void indexDocuments(final List<String> docs, final Map<String, String> metadata) {
log.info("Indexing {} documents...", docs.size()); log.info("Indexing {} documents...", docs.size());
for (String docText : docs) { for (final String docText : docs) {
List<TextSegment> segments = splitter.split(docText, metadata); final List<TextSegment> segments = splitter.split(docText, metadata);
if (!segments.isEmpty()) { if (!segments.isEmpty()) {
var embeddings = storeManager.getEmbeddingModel().embedAll(segments).content(); final var embeddings = storeManager.getEmbeddingModel().embedAll(segments).content();
storeManager.getEmbeddingStore().addAll(embeddings, segments); storeManager.getEmbeddingStore().addAll(embeddings, segments);
} }
} }
log.info("Indexing completed."); log.info("Indexing completed.");
} }
public String search(String query) { public String search(final String query) {
log.info("Searching for: {}", query); log.info("Searching for: {}", query);
var queryEmbedding = storeManager.getEmbeddingModel().embed(query).content(); final var queryEmbedding = storeManager.getEmbeddingModel().embed(query).content();
List<EmbeddingMatch<TextSegment>> matches = storeManager.getEmbeddingStore().findRelevant(queryEmbedding, 5);
String result = matches.stream() final String cachedResult = semanticCache.get(queryEmbedding);
if (cachedResult != null) {
log.info("⚡ Cache Hit! Returning result without vector search.");
return cachedResult;
}
final List<EmbeddingMatch<TextSegment>> matches = storeManager.getEmbeddingStore().findRelevant(queryEmbedding, 5);
final String result = matches.stream()
.map(m -> m.embedded().text()) .map(m -> m.embedded().text())
.collect(Collectors.joining("\n\n---\n\n")); .collect(Collectors.joining("\n\n---\n\n"));
log.info("Search completed with {} matches.", matches.size()); log.info("Search completed with {} matches.", matches.size());
if (!result.isBlank()) {
semanticCache.put(query, queryEmbedding, result);
}
return result; return result;
} }
@@ -99,7 +113,7 @@ public class LocalRagService {
return storeManager.getEmbeddingStore(); return storeManager.getEmbeddingStore();
} }
public void saveStore(Path path) { public void saveStore(final Path path) {
storeManager.saveStore(path); storeManager.saveStore(path);
} }
} }

View File

@@ -23,13 +23,13 @@ public class RagDocumentSplitter {
* @param metadata The metadata to associate with each segment. * @param metadata The metadata to associate with each segment.
* @return A list of text segments. * @return A list of text segments.
*/ */
public List<TextSegment> split(String docText, Map<String, String> metadata) { public List<TextSegment> split(final String docText, final Map<String, String> metadata) {
List<TextSegment> segments = new ArrayList<>(); final List<TextSegment> segments = new ArrayList<>();
Metadata lcMetadata = metadata != null ? Metadata.from(metadata) : new Metadata(); final Metadata lcMetadata = metadata != null ? Metadata.from(metadata) : new Metadata();
// Try to split by DSL-like blocks or Markdown headers // Try to split by DSL-like blocks or Markdown headers
String[] blocks = docText.split("\n(?=## )|\n(?=\\w+\\s*\\{)"); final String[] blocks = docText.split("\n(?=## )|\n(?=\\w+\\s*\\{)");
for (String block : blocks) { for (final String block : blocks) {
if (!block.isBlank()) { if (!block.isBlank()) {
segments.add(TextSegment.from(block.trim(), lcMetadata)); segments.add(TextSegment.from(block.trim(), lcMetadata));
} }
@@ -37,9 +37,9 @@ public class RagDocumentSplitter {
if (segments.isEmpty()) { if (segments.isEmpty()) {
// Fallback to recursive splitter if no blocks found // Fallback to recursive splitter if no blocks found
DocumentSplitter splitter = DocumentSplitters.recursive(500, 50); final DocumentSplitter splitter = DocumentSplitters.recursive(500, 50);
List<TextSegment> splitSegments = splitter.split(Document.from(docText)); final List<TextSegment> splitSegments = splitter.split(Document.from(docText));
for (TextSegment segment : splitSegments) { for (final TextSegment segment : splitSegments) {
segments.add(TextSegment.from(segment.text(), lcMetadata)); segments.add(TextSegment.from(segment.text(), lcMetadata));
} }
} }

View File

@@ -19,7 +19,7 @@ public class RagFileWatcher {
private final Consumer<Path> onFileDeleted; private final Consumer<Path> onFileDeleted;
private final Predicate<Path> fileFilter; private final Predicate<Path> fileFilter;
public RagFileWatcher(Path dir, Consumer<Path> onFileChanged, Consumer<Path> onFileDeleted, Predicate<Path> fileFilter) { public RagFileWatcher(final Path dir, final Consumer<Path> onFileChanged, final Consumer<Path> onFileDeleted, final Predicate<Path> fileFilter) {
this.dir = dir; this.dir = dir;
this.onFileChanged = onFileChanged; this.onFileChanged = onFileChanged;
this.onFileDeleted = onFileDeleted; this.onFileDeleted = onFileDeleted;
@@ -28,17 +28,17 @@ public class RagFileWatcher {
public void start() { public void start() {
CompletableFuture.runAsync(() -> { CompletableFuture.runAsync(() -> {
try (WatchService watcher = FileSystems.getDefault().newWatchService()) { try (final WatchService watcher = FileSystems.getDefault().newWatchService()) {
dir.register(watcher, StandardWatchEventKinds.ENTRY_CREATE, dir.register(watcher, StandardWatchEventKinds.ENTRY_CREATE,
StandardWatchEventKinds.ENTRY_DELETE, StandardWatchEventKinds.ENTRY_DELETE,
StandardWatchEventKinds.ENTRY_MODIFY); StandardWatchEventKinds.ENTRY_MODIFY);
log.info("Started watching directory for changes: {}", dir); log.info("Started watching directory for changes: {}", dir);
while (!Thread.currentThread().isInterrupted()) { while (!Thread.currentThread().isInterrupted()) {
WatchKey key = watcher.take(); final WatchKey key = watcher.take();
for (WatchEvent<?> event : key.pollEvents()) { for (final WatchEvent<?> event : key.pollEvents()) {
WatchEvent.Kind<?> kind = event.kind(); final WatchEvent.Kind<?> kind = event.kind();
Path fileName = (Path) event.context(); final Path fileName = (Path) event.context();
Path filePath = dir.resolve(fileName); final Path filePath = dir.resolve(fileName);
if (kind == StandardWatchEventKinds.ENTRY_CREATE || kind == StandardWatchEventKinds.ENTRY_MODIFY) { if (kind == StandardWatchEventKinds.ENTRY_CREATE || kind == StandardWatchEventKinds.ENTRY_MODIFY) {
if (Files.isRegularFile(filePath) && fileFilter.test(filePath)) { if (Files.isRegularFile(filePath) && fileFilter.test(filePath)) {
@@ -52,7 +52,7 @@ public class RagFileWatcher {
break; break;
} }
} }
} catch (Exception e) { } catch (final Exception e) {
log.error("Error in file watcher", e); log.error("Error in file watcher", e);
} }
}); });

View File

@@ -24,39 +24,39 @@ public class RagIndexer {
private final RagDocumentSplitter splitter; private final RagDocumentSplitter splitter;
private final Map<Path, List<String>> fileToIds = new ConcurrentHashMap<>(); private final Map<Path, List<String>> fileToIds = new ConcurrentHashMap<>();
public RagIndexer(RagStoreManager storeManager, RagDocumentSplitter splitter) { public RagIndexer(final RagStoreManager storeManager, final RagDocumentSplitter splitter) {
this.storeManager = storeManager; this.storeManager = storeManager;
this.splitter = splitter; this.splitter = splitter;
} }
public void indexDirectory(Path path) { public void indexDirectory(final Path path) {
indexDirectory(path, null); indexDirectory(path, null);
} }
public void indexDirectory(Path path, Map<String, String> metadata) { public void indexDirectory(final Path path, final Map<String, String> metadata) {
log.info("Indexing directory: {}", path); log.info("Indexing directory: {}", path);
try (Stream<Path> paths = Files.walk(path)) { try (final Stream<Path> paths = Files.walk(path)) {
paths.filter(Files::isRegularFile) paths.filter(Files::isRegularFile)
.filter(this::isSupportedFile) .filter(this::isSupportedFile)
.forEach(p -> indexFile(p, metadata)); .forEach(p -> indexFile(p, metadata));
} catch (IOException e) { } catch (final IOException e) {
log.error("Failed to index directory: {}", path, e); log.error("Failed to index directory: {}", path, e);
} }
} }
public void indexFile(Path path) { public void indexFile(final Path path) {
indexFile(path, null); indexFile(path, null);
} }
public void indexFile(Path path, Map<String, String> extraMetadata) { public void indexFile(final Path path, final Map<String, String> extraMetadata) {
removeFile(path); removeFile(path);
try { try {
String content = Files.readString(path); final String content = Files.readString(path);
if (content.isBlank()) { if (content.isBlank()) {
return; return;
} }
Map<String, String> metadata = new HashMap<>(); final Map<String, String> metadata = new HashMap<>();
metadata.put("file_path", path.toString()); metadata.put("file_path", path.toString());
if (extraMetadata != null) { if (extraMetadata != null) {
metadata.putAll(extraMetadata); metadata.putAll(extraMetadata);
@@ -64,28 +64,28 @@ public class RagIndexer {
List<TextSegment> segments = splitter.split(content, metadata); List<TextSegment> segments = splitter.split(content, metadata);
if (!segments.isEmpty()) { if (!segments.isEmpty()) {
List<Embedding> embeddings = storeManager.getEmbeddingModel().embedAll(segments).content(); final List<Embedding> embeddings = storeManager.getEmbeddingModel().embedAll(segments).content();
List<String> ids = storeManager.getEmbeddingStore().addAll(embeddings, segments); final List<String> ids = storeManager.getEmbeddingStore().addAll(embeddings, segments);
fileToIds.put(path, ids); fileToIds.put(path, ids);
log.info("Indexed file: {} ({} segments)", path, segments.size()); log.info("Indexed file: {} ({} segments)", path, segments.size());
} }
} catch (IOException e) { } catch (final IOException e) {
log.error("Failed to read file: {}", path, e); log.error("Failed to read file: {}", path, e);
} }
} }
public void removeFile(Path path) { public void removeFile(final Path path) {
List<String> ids = fileToIds.remove(path); final List<String> ids = fileToIds.remove(path);
if (ids != null) { if (ids != null) {
for (String id : ids) { for (final String id : ids) {
storeManager.getEmbeddingStore().remove(id); storeManager.getEmbeddingStore().remove(id);
} }
log.info("Removed file from index: {}", path); log.info("Removed file from index: {}", path);
} }
} }
public boolean isSupportedFile(Path path) { public boolean isSupportedFile(final Path path) {
String s = path.toString(); final String s = path.toString();
return s.endsWith(".md") || s.endsWith(".dsl") || s.endsWith(".txt"); return s.endsWith(".md") || s.endsWith(".dsl") || s.endsWith(".txt");
} }

View File

@@ -26,14 +26,14 @@ public class RagStoreManager {
this.embeddingModel = new AllMiniLmL6V2EmbeddingModel(); this.embeddingModel = new AllMiniLmL6V2EmbeddingModel();
} }
public void loadBundledStore(String resourcePath) { public void loadBundledStore(final String resourcePath) {
try (InputStream is = getClass().getClassLoader().getResourceAsStream(resourcePath)) { try (final InputStream is = getClass().getClassLoader().getResourceAsStream(resourcePath)) {
if (is != null) { if (is != null) {
String json = new String(is.readAllBytes(), StandardCharsets.UTF_8); final String json = new String(is.readAllBytes(), StandardCharsets.UTF_8);
this.embeddingStore = InMemoryEmbeddingStore.fromJson(json); this.embeddingStore = InMemoryEmbeddingStore.fromJson(json);
log.info("Loaded bundled embedding store from {}", resourcePath); log.info("Loaded bundled embedding store from {}", resourcePath);
} }
} catch (IOException e) { } catch (final IOException e) {
log.error("Failed to load bundled embedding store", e); log.error("Failed to load bundled embedding store", e);
} }
} }
@@ -44,13 +44,13 @@ public class RagStoreManager {
} }
} }
public void saveStore(Path path) { public void saveStore(final Path path) {
if (embeddingStore == null) return; if (embeddingStore == null) return;
try { try {
String json = embeddingStore.serializeToJson(); final String json = embeddingStore.serializeToJson();
Files.writeString(path, json); Files.writeString(path, json);
log.info("Saved embedding store to {}", path); log.info("Saved embedding store to {}", path);
} catch (IOException e) { } catch (final IOException e) {
log.error("Failed to save embedding store", e); log.error("Failed to save embedding store", e);
} }
} }

View File

@@ -9,7 +9,7 @@ import java.nio.file.StandardOpenOption;
public class TestDataGenerator { public class TestDataGenerator {
public static void main(String[] args) { public static void main(String[] args) {
Path root = Paths.get("test-data"); final Path root = Paths.get("test-data");
try { try {
// 1. Create Directories // 1. Create Directories
@@ -86,13 +86,13 @@ public class TestDataGenerator {
System.out.println("✅ Test data generated successfully in: " + root.toAbsolutePath()); System.out.println("✅ Test data generated successfully in: " + root.toAbsolutePath());
} catch (IOException e) { } catch (final IOException e) {
System.err.println("Failed to generate test data: " + e.getMessage()); System.err.println("Failed to generate test data: " + e.getMessage());
e.printStackTrace(); e.printStackTrace();
} }
} }
private static void createFile(Path path, String content) throws IOException { private static void createFile(final Path path, final String content) throws IOException {
Files.writeString(path, content, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING); Files.writeString(path, content, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING);
System.out.println("Created: " + path); System.out.println("Created: " + path);
} }

View File

@@ -0,0 +1,63 @@
package de.shahondin1624.rag.cache;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.store.embedding.CosineSimilarity;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
 * A small LRU cache keyed by query embeddings. A lookup succeeds when a stored
 * query's embedding is cosine-similar to the incoming one above a configurable
 * threshold, allowing semantically equivalent queries to skip the vector search.
 * Thread-safe via a read/write lock.
 */
public class SemanticCache {

    private static final Logger log = LoggerFactory.getLogger(SemanticCache.class);

    private final int maxEntries;
    private final double similarityThreshold;
    private final Map<CachedQuery, String> cache;
    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

    /**
     * Creates a semantic cache backed by an access-ordered LRU map.
     *
     * @param maxEntries          maximum number of cached query/result pairs; the
     *                            least-recently-used entry is evicted beyond this size
     * @param similarityThreshold minimum cosine similarity for a cached embedding
     *                            to count as a hit (e.g. 0.92)
     */
    public SemanticCache(final int maxEntries, final double similarityThreshold) {
        this.maxEntries = maxEntries;
        this.similarityThreshold = similarityThreshold;
        // accessOrder = true turns the LinkedHashMap into an LRU structure;
        // removeEldestEntry enforces the size bound on insertion.
        this.cache = new LinkedHashMap<>(maxEntries, 0.75f, true) {
            @Override
            protected boolean removeEldestEntry(final Map.Entry<CachedQuery, String> eldest) {
                return size() > SemanticCache.this.maxEntries;
            }
        };
    }

    /**
     * Tries to find a semantically similar query in the cache.
     *
     * @param incomingEmbedding the vector of the new user query
     * @return the cached string result, or {@code null} if no entry reaches the threshold
     */
    public String get(final Embedding incomingEmbedding) {
        lock.readLock().lock();
        try {
            // Linear scan of the cache (fast for <100 entries). Iterating the
            // entry set does not update access order, so a read lock is sufficient.
            for (final Map.Entry<CachedQuery, String> entry : cache.entrySet()) {
                // Record components are exposed via accessor methods (embedding()),
                // not public fields — the field-style access ".embedding" does not compile.
                final double similarity = CosineSimilarity.between(incomingEmbedding, entry.getKey().embedding());
                if (similarity >= similarityThreshold) {
                    // SLF4J instead of System.out, consistent with the rest of the project;
                    // the caller (LocalRagService) already logs the user-facing hit message.
                    log.debug("Semantic cache hit (similarity: {})", similarity);
                    return entry.getValue();
                }
            }
            return null;
        } finally {
            lock.readLock().unlock();
        }
    }

    /**
     * Stores a query/result pair, evicting the least-recently-used entry when full.
     *
     * @param originalQueryText the raw query text (kept alongside the vector for inspection)
     * @param embedding         the query's embedding, used for future similarity matching
     * @param result            the search result to cache
     */
    public void put(final String originalQueryText, final Embedding embedding, final String result) {
        lock.writeLock().lock();
        try {
            cache.put(new CachedQuery(originalQueryText, embedding), result);
        } finally {
            lock.writeLock().unlock();
        }
    }

    // Immutable key pairing the original query text with its embedding
    // (room for future metadata such as timestamps).
    private record CachedQuery(String text, Embedding embedding) {}
}

View File

@@ -56,15 +56,15 @@ public class RagEmbedTool extends McpValidatedTool {
} }
@Override @Override
public McpSchema.CallToolResult callValidated(McpSchema.CallToolRequest request, Map<String, Object> arguments) { public McpSchema.CallToolResult callValidated(final McpSchema.CallToolRequest request, final Map<String, Object> arguments) {
String content = (String) arguments.get("content"); final String content = (String) arguments.get("content");
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Map<String, Object> rawMetadata = (Map<String, Object>) arguments.get("metadata"); final Map<String, Object> rawMetadata = (Map<String, Object>) arguments.get("metadata");
Map<String, String> metadata = null; Map<String, String> metadata = null;
if (rawMetadata != null) { if (rawMetadata != null) {
metadata = new HashMap<>(); metadata = new HashMap<>();
for (Map.Entry<String, Object> entry : rawMetadata.entrySet()) { for (final Map.Entry<String, Object> entry : rawMetadata.entrySet()) {
metadata.put(entry.getKey(), String.valueOf(entry.getValue())); metadata.put(entry.getKey(), String.valueOf(entry.getValue()));
} }
} }
@@ -74,7 +74,7 @@ public class RagEmbedTool extends McpValidatedTool {
try { try {
LocalRagService.getInstance().indexDocuments(List.of(content), metadata); LocalRagService.getInstance().indexDocuments(List.of(content), metadata);
return successResult("Information successfully embedded and indexed."); return successResult("Information successfully embedded and indexed.");
} catch (Exception e) { } catch (final Exception e) {
logger.error("Error during information embedding", e); logger.error("Error during information embedding", e);
return error("Error during information embedding: " + e.getMessage()); return error("Error during information embedding: " + e.getMessage());
} }

View File

@@ -48,14 +48,14 @@ public class RagSearchTool extends McpValidatedTool {
} }
@Override @Override
public McpSchema.CallToolResult callValidated(McpSchema.CallToolRequest request, Map<String, Object> arguments) { public McpSchema.CallToolResult callValidated(final McpSchema.CallToolRequest request, final Map<String, Object> arguments) {
String query = (String) arguments.get("query"); final String query = (String) arguments.get("query");
logger.debug("RagSearchTool called with query: {}", query); logger.debug("RagSearchTool called with query: {}", query);
try { try {
String result = LocalRagService.getInstance().search(query); final String result = LocalRagService.getInstance().search(query);
return successResult(result); return successResult(result);
} catch (Exception e) { } catch (final Exception e) {
logger.error("Error during RAG search", e); logger.error("Error during RAG search", e);
return error("Error during RAG search: " + e.getMessage()); return error("Error during RAG search: " + e.getMessage());
} }