diff --git a/src/main/java/de/shahondin1624/rag/LocalRagService.java b/src/main/java/de/shahondin1624/rag/LocalRagService.java index 5076909..1831ac2 100644 --- a/src/main/java/de/shahondin1624/rag/LocalRagService.java +++ b/src/main/java/de/shahondin1624/rag/LocalRagService.java @@ -1,5 +1,6 @@ package de.shahondin1624.rag; +import de.shahondin1624.rag.cache.SemanticCache; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; @@ -7,7 +8,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.nio.file.Path; -import java.nio.file.Paths; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -24,6 +24,8 @@ public class LocalRagService { private final RagIndexer indexer; private final RagDocumentSplitter splitter; + private final SemanticCache semanticCache; + public static synchronized LocalRagService getInstance() { if (instance == null) { instance = new LocalRagService(); @@ -35,6 +37,7 @@ public class LocalRagService { this.storeManager = new RagStoreManager(); this.splitter = new RagDocumentSplitter(); this.indexer = new RagIndexer(storeManager, splitter); + this.semanticCache = new SemanticCache(50, 0.92); initialize(); } @@ -45,53 +48,64 @@ public class LocalRagService { log.info("LocalRagService initialized successfully."); } - public void indexDirectory(Path path) { + public void indexDirectory(final Path path) { indexDirectory(path, null); } - public void indexDirectory(Path path, Map metadata) { + public void indexDirectory(final Path path, final Map metadata) { indexer.indexDirectory(path, metadata); new RagFileWatcher(path, p -> indexer.indexFile(p, metadata), indexer::removeFile, indexer::isSupportedFile).start(); } - public void indexFile(Path path) { + public void indexFile(final Path path) { indexFile(path, null); } - public void indexFile(Path path, Map metadata) { + public void indexFile(final Path path, final Map metadata) { indexer.indexFile(path, metadata); } - public void removeFile(Path path) { + public void removeFile(final Path path) { indexer.removeFile(path); } - public void indexDocuments(List docs) { + public void indexDocuments(final List docs) { indexDocuments(docs, null); } - public void indexDocuments(List docs, Map metadata) { + public void indexDocuments(final List docs, final Map metadata) { log.info("Indexing {} documents...", docs.size()); - for (String docText : docs) { - List segments = splitter.split(docText, metadata); + for (final String docText : docs) { + final List segments = splitter.split(docText, metadata); if (!segments.isEmpty()) { - var embeddings = storeManager.getEmbeddingModel().embedAll(segments).content(); + final var embeddings = storeManager.getEmbeddingModel().embedAll(segments).content(); storeManager.getEmbeddingStore().addAll(embeddings, segments); } } log.info("Indexing completed."); } - public String search(String query) { + public String search(final String query) { log.info("Searching for: {}", query); - var queryEmbedding = storeManager.getEmbeddingModel().embed(query).content(); - List> matches = storeManager.getEmbeddingStore().findRelevant(queryEmbedding, 5); + final var queryEmbedding = storeManager.getEmbeddingModel().embed(query).content(); - String result = matches.stream() + final String cachedResult = semanticCache.get(queryEmbedding); + if (cachedResult != null) { + log.info("⚡ Cache Hit! Returning result without vector search."); + return cachedResult; + } + + final List> matches = storeManager.getEmbeddingStore().findRelevant(queryEmbedding, 5); + + final String result = matches.stream() .map(m -> m.embedded().text()) .collect(Collectors.joining("\n\n---\n\n")); log.info("Search completed with {} matches.", matches.size()); + + if (!result.isBlank()) { + semanticCache.put(query, queryEmbedding, result); + } return result; } @@ -99,7 +113,7 @@ public class LocalRagService { return storeManager.getEmbeddingStore(); } - public void saveStore(Path path) { + public void saveStore(final Path path) { storeManager.saveStore(path); } } diff --git a/src/main/java/de/shahondin1624/rag/RagDocumentSplitter.java b/src/main/java/de/shahondin1624/rag/RagDocumentSplitter.java index 0033d65..875da83 100644 --- a/src/main/java/de/shahondin1624/rag/RagDocumentSplitter.java +++ b/src/main/java/de/shahondin1624/rag/RagDocumentSplitter.java @@ -23,13 +23,13 @@ public class RagDocumentSplitter { * @param metadata The metadata to associate with each segment. * @return A list of text segments. */ - public List split(String docText, Map metadata) { - List segments = new ArrayList<>(); - Metadata lcMetadata = metadata != null ? Metadata.from(metadata) : new Metadata(); + public List split(final String docText, final Map metadata) { + final List segments = new ArrayList<>(); + final Metadata lcMetadata = metadata != null ? Metadata.from(metadata) : new Metadata(); // Try to split by DSL-like blocks or Markdown headers - String[] blocks = docText.split("\n(?=## )|\n(?=\\w+\\s*\\{)"); - for (String block : blocks) { + final String[] blocks = docText.split("\n(?=## )|\n(?=\\w+\\s*\\{)"); + for (final String block : blocks) { if (!block.isBlank()) { segments.add(TextSegment.from(block.trim(), lcMetadata)); } @@ -37,9 +37,9 @@ public class RagDocumentSplitter { if (segments.isEmpty()) { // Fallback to recursive splitter if no blocks found - DocumentSplitter splitter = DocumentSplitters.recursive(500, 50); - List splitSegments = splitter.split(Document.from(docText)); - for (TextSegment segment : splitSegments) { + final DocumentSplitter splitter = DocumentSplitters.recursive(500, 50); + final List splitSegments = splitter.split(Document.from(docText)); + for (final TextSegment segment : splitSegments) { segments.add(TextSegment.from(segment.text(), lcMetadata)); } } diff --git a/src/main/java/de/shahondin1624/rag/RagFileWatcher.java b/src/main/java/de/shahondin1624/rag/RagFileWatcher.java index 0613218..f914484 100644 --- a/src/main/java/de/shahondin1624/rag/RagFileWatcher.java +++ b/src/main/java/de/shahondin1624/rag/RagFileWatcher.java @@ -19,7 +19,7 @@ public class RagFileWatcher { private final Consumer onFileDeleted; private final Predicate fileFilter; - public RagFileWatcher(Path dir, Consumer onFileChanged, Consumer onFileDeleted, Predicate fileFilter) { + public RagFileWatcher(final Path dir, final Consumer onFileChanged, final Consumer onFileDeleted, final Predicate fileFilter) { this.dir = dir; this.onFileChanged = onFileChanged; this.onFileDeleted = onFileDeleted; @@ -28,17 +28,17 @@ public class RagFileWatcher { public void start() { CompletableFuture.runAsync(() -> { - try (WatchService watcher = FileSystems.getDefault().newWatchService()) { + try (final WatchService watcher = FileSystems.getDefault().newWatchService()) { dir.register(watcher, StandardWatchEventKinds.ENTRY_CREATE, StandardWatchEventKinds.ENTRY_DELETE, StandardWatchEventKinds.ENTRY_MODIFY); log.info("Started watching directory for changes: {}", dir); while (!Thread.currentThread().isInterrupted()) { - WatchKey key = watcher.take(); - for (WatchEvent event : key.pollEvents()) { - WatchEvent.Kind kind = event.kind(); - Path fileName = (Path) event.context(); - Path filePath = dir.resolve(fileName); + final WatchKey key = watcher.take(); + for (final WatchEvent event : key.pollEvents()) { + final WatchEvent.Kind kind = event.kind(); + final Path fileName = (Path) event.context(); + final Path filePath = dir.resolve(fileName); if (kind == StandardWatchEventKinds.ENTRY_CREATE || kind == StandardWatchEventKinds.ENTRY_MODIFY) { if (Files.isRegularFile(filePath) && fileFilter.test(filePath)) { @@ -52,7 +52,7 @@ public class RagFileWatcher { break; } } - } catch (Exception e) { + } catch (final Exception e) { log.error("Error in file watcher", e); } }); diff --git a/src/main/java/de/shahondin1624/rag/RagIndexer.java b/src/main/java/de/shahondin1624/rag/RagIndexer.java index 2c8ecd5..7f9a0bc 100644 --- a/src/main/java/de/shahondin1624/rag/RagIndexer.java +++ b/src/main/java/de/shahondin1624/rag/RagIndexer.java @@ -24,39 +24,39 @@ public class RagIndexer { private final RagDocumentSplitter splitter; private final Map> fileToIds = new ConcurrentHashMap<>(); - public RagIndexer(RagStoreManager storeManager, RagDocumentSplitter splitter) { + public RagIndexer(final RagStoreManager storeManager, final RagDocumentSplitter splitter) { this.storeManager = storeManager; this.splitter = splitter; } - public void indexDirectory(Path path) { + public void indexDirectory(final Path path) { indexDirectory(path, null); } - public void indexDirectory(Path path, Map metadata) { + public void indexDirectory(final Path path, final Map metadata) { log.info("Indexing directory: {}", path); - try (Stream paths = Files.walk(path)) { + try (final Stream paths = Files.walk(path)) { paths.filter(Files::isRegularFile) .filter(this::isSupportedFile) .forEach(p -> indexFile(p, metadata)); - } catch (IOException e) { + } catch (final IOException e) { log.error("Failed to index directory: {}", path, e); } } - public void indexFile(Path path) { + public void indexFile(final Path path) { indexFile(path, null); } - public void indexFile(Path path, Map extraMetadata) { + public void indexFile(final Path path, final Map extraMetadata) { removeFile(path); try { - String content = Files.readString(path); + final String content = Files.readString(path); if (content.isBlank()) { return; } - Map metadata = new HashMap<>(); + final Map metadata = new HashMap<>(); metadata.put("file_path", path.toString()); if (extraMetadata != null) { metadata.putAll(extraMetadata); @@ -64,28 +64,28 @@ public class RagIndexer { List segments = splitter.split(content, metadata); if (!segments.isEmpty()) { - List embeddings = storeManager.getEmbeddingModel().embedAll(segments).content(); - List ids = storeManager.getEmbeddingStore().addAll(embeddings, segments); + final List embeddings = storeManager.getEmbeddingModel().embedAll(segments).content(); + final List ids = storeManager.getEmbeddingStore().addAll(embeddings, segments); fileToIds.put(path, ids); log.info("Indexed file: {} ({} segments)", path, segments.size()); } - } catch (IOException e) { + } catch (final IOException e) { log.error("Failed to read file: {}", path, e); } } - public void removeFile(Path path) { - List ids = fileToIds.remove(path); + public void removeFile(final Path path) { + final List ids = fileToIds.remove(path); if (ids != null) { - for (String id : ids) { + for (final String id : ids) { storeManager.getEmbeddingStore().remove(id); } log.info("Removed file from index: {}", path); } } - public boolean isSupportedFile(Path path) { - String s = path.toString(); + public boolean isSupportedFile(final Path path) { + final String s = path.toString(); return s.endsWith(".md") || s.endsWith(".dsl") || s.endsWith(".txt"); } diff --git a/src/main/java/de/shahondin1624/rag/RagStoreManager.java b/src/main/java/de/shahondin1624/rag/RagStoreManager.java index 760cf2a..a249170 100644 --- a/src/main/java/de/shahondin1624/rag/RagStoreManager.java +++ b/src/main/java/de/shahondin1624/rag/RagStoreManager.java @@ -26,14 +26,14 @@ public class RagStoreManager { this.embeddingModel = new AllMiniLmL6V2EmbeddingModel(); } - public void loadBundledStore(String resourcePath) { - try (InputStream is = getClass().getClassLoader().getResourceAsStream(resourcePath)) { + public void loadBundledStore(final String resourcePath) { + try (final InputStream is = getClass().getClassLoader().getResourceAsStream(resourcePath)) { if (is != null) { - String json = new String(is.readAllBytes(), StandardCharsets.UTF_8); + final String json = new String(is.readAllBytes(), StandardCharsets.UTF_8); this.embeddingStore = InMemoryEmbeddingStore.fromJson(json); log.info("Loaded bundled embedding store from {}", resourcePath); } - } catch (IOException e) { + } catch (final IOException e) { log.error("Failed to load bundled embedding store", e); } } @@ -44,13 +44,13 @@ public class RagStoreManager { } } - public void saveStore(Path path) { + public void saveStore(final Path path) { if (embeddingStore == null) return; try { - String json = embeddingStore.serializeToJson(); + final String json = embeddingStore.serializeToJson(); Files.writeString(path, json); log.info("Saved embedding store to {}", path); - } catch (IOException e) { + } catch (final IOException e) { log.error("Failed to save embedding store", e); } } diff --git a/src/main/java/de/shahondin1624/rag/TestDataGenerator.java b/src/main/java/de/shahondin1624/rag/TestDataGenerator.java index 33ee34d..ff64497 100644 --- a/src/main/java/de/shahondin1624/rag/TestDataGenerator.java +++ b/src/main/java/de/shahondin1624/rag/TestDataGenerator.java @@ -9,7 +9,7 @@ import java.nio.file.StandardOpenOption; public class TestDataGenerator { public static void main(String[] args) { - Path root = Paths.get("test-data"); + final Path root = Paths.get("test-data"); try { // 1. Create Directories @@ -86,13 +86,13 @@ public class TestDataGenerator { System.out.println("✅ Test data generated successfully in: " + root.toAbsolutePath()); - } catch (IOException e) { + } catch (final IOException e) { System.err.println("Failed to generate test data: " + e.getMessage()); e.printStackTrace(); } } - private static void createFile(Path path, String content) throws IOException { + private static void createFile(final Path path, final String content) throws IOException { Files.writeString(path, content, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING); System.out.println("Created: " + path); } diff --git a/src/main/java/de/shahondin1624/rag/cache/SemanticCache.java b/src/main/java/de/shahondin1624/rag/cache/SemanticCache.java new file mode 100644 index 0000000..e1c5bfe --- /dev/null +++ b/src/main/java/de/shahondin1624/rag/cache/SemanticCache.java @@ -0,0 +1,63 @@ +package de.shahondin1624.rag.cache; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.store.embedding.CosineSimilarity; + +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +public class SemanticCache { + + private final int maxEntries; + private final double similarityThreshold; + + private final Map cache; + private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); + + public SemanticCache(final int maxEntries, final double similarityThreshold) { + this.maxEntries = maxEntries; + this.similarityThreshold = similarityThreshold; + + this.cache = new LinkedHashMap<>(maxEntries, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > SemanticCache.this.maxEntries; + } + }; + } + + /** + * Tries to find a semantically similar query in the cache. + * @param incomingEmbedding The vector of the new user query. + * @return The cached string result, or null if no match found. + */ + public String get(final Embedding incomingEmbedding) { + lock.readLock().lock(); + try { + // Linear scan of the cache (fast for <100 entries) + for (final Map.Entry entry : cache.entrySet()) { + final double similarity = CosineSimilarity.between(incomingEmbedding, entry.getKey().embedding); + if (similarity >= similarityThreshold) { + System.out.println("⚡ Semantic Cache Hit! (Similarity: " + similarity + ")"); + return entry.getValue(); + } + } + return null; + } finally { + lock.readLock().unlock(); + } + } + + public void put(final String originalQueryText, final Embedding embedding, final String result) { + lock.writeLock().lock(); + try { + cache.put(new CachedQuery(originalQueryText, embedding), result); + } finally { + lock.writeLock().unlock(); + } + } + + // Wrapper class to store metadata if needed (e.g., timestamps) + private record CachedQuery(String text, Embedding embedding) {} +} \ No newline at end of file diff --git a/src/main/java/de/shahondin1624/rag/tooling/RagEmbedTool.java b/src/main/java/de/shahondin1624/rag/tooling/RagEmbedTool.java index 0afdc6a..6367ac2 100644 --- a/src/main/java/de/shahondin1624/rag/tooling/RagEmbedTool.java +++ b/src/main/java/de/shahondin1624/rag/tooling/RagEmbedTool.java @@ -56,15 +56,15 @@ public class RagEmbedTool extends McpValidatedTool { } @Override - public McpSchema.CallToolResult callValidated(McpSchema.CallToolRequest request, Map arguments) { - String content = (String) arguments.get("content"); + public McpSchema.CallToolResult callValidated(final McpSchema.CallToolRequest request, final Map arguments) { + final String content = (String) arguments.get("content"); @SuppressWarnings("unchecked") - Map rawMetadata = (Map) arguments.get("metadata"); + final Map rawMetadata = (Map) arguments.get("metadata"); Map metadata = null; if (rawMetadata != null) { metadata = new HashMap<>(); - for (Map.Entry entry : rawMetadata.entrySet()) { + for (final Map.Entry entry : rawMetadata.entrySet()) { metadata.put(entry.getKey(), String.valueOf(entry.getValue())); } } @@ -74,7 +74,7 @@ public class RagEmbedTool extends McpValidatedTool { try { LocalRagService.getInstance().indexDocuments(List.of(content), metadata); return successResult("Information successfully embedded and indexed."); - } catch (Exception e) { + } catch (final Exception e) { logger.error("Error during information embedding", e); return error("Error during information embedding: " + e.getMessage()); } diff --git a/src/main/java/de/shahondin1624/rag/tooling/RagSearchTool.java b/src/main/java/de/shahondin1624/rag/tooling/RagSearchTool.java index d1a5a8a..5a99c14 100644 --- a/src/main/java/de/shahondin1624/rag/tooling/RagSearchTool.java +++ b/src/main/java/de/shahondin1624/rag/tooling/RagSearchTool.java @@ -48,14 +48,14 @@ public class RagSearchTool extends McpValidatedTool { } @Override - public McpSchema.CallToolResult callValidated(McpSchema.CallToolRequest request, Map arguments) { - String query = (String) arguments.get("query"); + public McpSchema.CallToolResult callValidated(final McpSchema.CallToolRequest request, final Map arguments) { + final String query = (String) arguments.get("query"); logger.debug("RagSearchTool called with query: {}", query); try { - String result = LocalRagService.getInstance().search(query); + final String result = LocalRagService.getInstance().search(query); return successResult(result); - } catch (Exception e) { + } catch (final Exception e) { logger.error("Error during RAG search", e); return error("Error during RAG search: " + e.getMessage()); }