Refactor and add semantic cache

This commit is contained in:
shahondin1624
2026-02-15 17:09:55 +01:00
parent a0e36a7fe2
commit dcbb41d747
9 changed files with 145 additions and 68 deletions

View File

@@ -1,5 +1,6 @@
package de.shahondin1624.rag;
import de.shahondin1624.rag.cache.SemanticCache;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
@@ -7,7 +8,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@@ -24,6 +24,8 @@ public class LocalRagService {
private final RagIndexer indexer;
private final RagDocumentSplitter splitter;
private final SemanticCache semanticCache;
public static synchronized LocalRagService getInstance() {
if (instance == null) {
instance = new LocalRagService();
@@ -35,6 +37,7 @@ public class LocalRagService {
this.storeManager = new RagStoreManager();
this.splitter = new RagDocumentSplitter();
this.indexer = new RagIndexer(storeManager, splitter);
this.semanticCache = new SemanticCache(50, 0.92);
initialize();
}
@@ -45,53 +48,64 @@ public class LocalRagService {
log.info("LocalRagService initialized successfully.");
}
public void indexDirectory(Path path) {
public void indexDirectory(final Path path) {
indexDirectory(path, null);
}
public void indexDirectory(Path path, Map<String, String> metadata) {
public void indexDirectory(final Path path, final Map<String, String> metadata) {
indexer.indexDirectory(path, metadata);
new RagFileWatcher(path, p -> indexer.indexFile(p, metadata), indexer::removeFile, indexer::isSupportedFile).start();
}
public void indexFile(Path path) {
public void indexFile(final Path path) {
indexFile(path, null);
}
public void indexFile(Path path, Map<String, String> metadata) {
public void indexFile(final Path path, final Map<String, String> metadata) {
indexer.indexFile(path, metadata);
}
public void removeFile(Path path) {
public void removeFile(final Path path) {
indexer.removeFile(path);
}
public void indexDocuments(List<String> docs) {
public void indexDocuments(final List<String> docs) {
indexDocuments(docs, null);
}
public void indexDocuments(List<String> docs, Map<String, String> metadata) {
public void indexDocuments(final List<String> docs, final Map<String, String> metadata) {
log.info("Indexing {} documents...", docs.size());
for (String docText : docs) {
List<TextSegment> segments = splitter.split(docText, metadata);
for (final String docText : docs) {
final List<TextSegment> segments = splitter.split(docText, metadata);
if (!segments.isEmpty()) {
var embeddings = storeManager.getEmbeddingModel().embedAll(segments).content();
final var embeddings = storeManager.getEmbeddingModel().embedAll(segments).content();
storeManager.getEmbeddingStore().addAll(embeddings, segments);
}
}
log.info("Indexing completed.");
}
public String search(String query) {
public String search(final String query) {
log.info("Searching for: {}", query);
var queryEmbedding = storeManager.getEmbeddingModel().embed(query).content();
List<EmbeddingMatch<TextSegment>> matches = storeManager.getEmbeddingStore().findRelevant(queryEmbedding, 5);
final var queryEmbedding = storeManager.getEmbeddingModel().embed(query).content();
String result = matches.stream()
final String cachedResult = semanticCache.get(queryEmbedding);
if (cachedResult != null) {
log.info("⚡ Cache Hit! Returning result without vector search.");
return cachedResult;
}
final List<EmbeddingMatch<TextSegment>> matches = storeManager.getEmbeddingStore().findRelevant(queryEmbedding, 5);
final String result = matches.stream()
.map(m -> m.embedded().text())
.collect(Collectors.joining("\n\n---\n\n"));
log.info("Search completed with {} matches.", matches.size());
if (!result.isBlank()) {
semanticCache.put(query, queryEmbedding, result);
}
return result;
}
@@ -99,7 +113,7 @@ public class LocalRagService {
return storeManager.getEmbeddingStore();
}
public void saveStore(Path path) {
public void saveStore(final Path path) {
storeManager.saveStore(path);
}
}

View File

@@ -23,13 +23,13 @@ public class RagDocumentSplitter {
* @param metadata The metadata to associate with each segment.
* @return A list of text segments.
*/
public List<TextSegment> split(String docText, Map<String, String> metadata) {
List<TextSegment> segments = new ArrayList<>();
Metadata lcMetadata = metadata != null ? Metadata.from(metadata) : new Metadata();
public List<TextSegment> split(final String docText, final Map<String, String> metadata) {
final List<TextSegment> segments = new ArrayList<>();
final Metadata lcMetadata = metadata != null ? Metadata.from(metadata) : new Metadata();
// Try to split by DSL-like blocks or Markdown headers
String[] blocks = docText.split("\n(?=## )|\n(?=\\w+\\s*\\{)");
for (String block : blocks) {
final String[] blocks = docText.split("\n(?=## )|\n(?=\\w+\\s*\\{)");
for (final String block : blocks) {
if (!block.isBlank()) {
segments.add(TextSegment.from(block.trim(), lcMetadata));
}
@@ -37,9 +37,9 @@ public class RagDocumentSplitter {
if (segments.isEmpty()) {
// Fallback to recursive splitter if no blocks found
DocumentSplitter splitter = DocumentSplitters.recursive(500, 50);
List<TextSegment> splitSegments = splitter.split(Document.from(docText));
for (TextSegment segment : splitSegments) {
final DocumentSplitter splitter = DocumentSplitters.recursive(500, 50);
final List<TextSegment> splitSegments = splitter.split(Document.from(docText));
for (final TextSegment segment : splitSegments) {
segments.add(TextSegment.from(segment.text(), lcMetadata));
}
}

View File

@@ -19,7 +19,7 @@ public class RagFileWatcher {
private final Consumer<Path> onFileDeleted;
private final Predicate<Path> fileFilter;
public RagFileWatcher(Path dir, Consumer<Path> onFileChanged, Consumer<Path> onFileDeleted, Predicate<Path> fileFilter) {
public RagFileWatcher(final Path dir, final Consumer<Path> onFileChanged, final Consumer<Path> onFileDeleted, final Predicate<Path> fileFilter) {
this.dir = dir;
this.onFileChanged = onFileChanged;
this.onFileDeleted = onFileDeleted;
@@ -28,17 +28,17 @@ public class RagFileWatcher {
public void start() {
CompletableFuture.runAsync(() -> {
try (WatchService watcher = FileSystems.getDefault().newWatchService()) {
try (final WatchService watcher = FileSystems.getDefault().newWatchService()) {
dir.register(watcher, StandardWatchEventKinds.ENTRY_CREATE,
StandardWatchEventKinds.ENTRY_DELETE,
StandardWatchEventKinds.ENTRY_MODIFY);
log.info("Started watching directory for changes: {}", dir);
while (!Thread.currentThread().isInterrupted()) {
WatchKey key = watcher.take();
for (WatchEvent<?> event : key.pollEvents()) {
WatchEvent.Kind<?> kind = event.kind();
Path fileName = (Path) event.context();
Path filePath = dir.resolve(fileName);
final WatchKey key = watcher.take();
for (final WatchEvent<?> event : key.pollEvents()) {
final WatchEvent.Kind<?> kind = event.kind();
final Path fileName = (Path) event.context();
final Path filePath = dir.resolve(fileName);
if (kind == StandardWatchEventKinds.ENTRY_CREATE || kind == StandardWatchEventKinds.ENTRY_MODIFY) {
if (Files.isRegularFile(filePath) && fileFilter.test(filePath)) {
@@ -52,7 +52,7 @@ public class RagFileWatcher {
break;
}
}
} catch (Exception e) {
} catch (final Exception e) {
log.error("Error in file watcher", e);
}
});

View File

@@ -24,39 +24,39 @@ public class RagIndexer {
private final RagDocumentSplitter splitter;
private final Map<Path, List<String>> fileToIds = new ConcurrentHashMap<>();
public RagIndexer(RagStoreManager storeManager, RagDocumentSplitter splitter) {
public RagIndexer(final RagStoreManager storeManager, final RagDocumentSplitter splitter) {
this.storeManager = storeManager;
this.splitter = splitter;
}
public void indexDirectory(Path path) {
public void indexDirectory(final Path path) {
indexDirectory(path, null);
}
public void indexDirectory(Path path, Map<String, String> metadata) {
public void indexDirectory(final Path path, final Map<String, String> metadata) {
log.info("Indexing directory: {}", path);
try (Stream<Path> paths = Files.walk(path)) {
try (final Stream<Path> paths = Files.walk(path)) {
paths.filter(Files::isRegularFile)
.filter(this::isSupportedFile)
.forEach(p -> indexFile(p, metadata));
} catch (IOException e) {
} catch (final IOException e) {
log.error("Failed to index directory: {}", path, e);
}
}
public void indexFile(Path path) {
public void indexFile(final Path path) {
indexFile(path, null);
}
public void indexFile(Path path, Map<String, String> extraMetadata) {
public void indexFile(final Path path, final Map<String, String> extraMetadata) {
removeFile(path);
try {
String content = Files.readString(path);
final String content = Files.readString(path);
if (content.isBlank()) {
return;
}
Map<String, String> metadata = new HashMap<>();
final Map<String, String> metadata = new HashMap<>();
metadata.put("file_path", path.toString());
if (extraMetadata != null) {
metadata.putAll(extraMetadata);
@@ -64,28 +64,28 @@ public class RagIndexer {
List<TextSegment> segments = splitter.split(content, metadata);
if (!segments.isEmpty()) {
List<Embedding> embeddings = storeManager.getEmbeddingModel().embedAll(segments).content();
List<String> ids = storeManager.getEmbeddingStore().addAll(embeddings, segments);
final List<Embedding> embeddings = storeManager.getEmbeddingModel().embedAll(segments).content();
final List<String> ids = storeManager.getEmbeddingStore().addAll(embeddings, segments);
fileToIds.put(path, ids);
log.info("Indexed file: {} ({} segments)", path, segments.size());
}
} catch (IOException e) {
} catch (final IOException e) {
log.error("Failed to read file: {}", path, e);
}
}
public void removeFile(Path path) {
List<String> ids = fileToIds.remove(path);
public void removeFile(final Path path) {
final List<String> ids = fileToIds.remove(path);
if (ids != null) {
for (String id : ids) {
for (final String id : ids) {
storeManager.getEmbeddingStore().remove(id);
}
log.info("Removed file from index: {}", path);
}
}
public boolean isSupportedFile(Path path) {
String s = path.toString();
public boolean isSupportedFile(final Path path) {
final String s = path.toString();
return s.endsWith(".md") || s.endsWith(".dsl") || s.endsWith(".txt");
}

View File

@@ -26,14 +26,14 @@ public class RagStoreManager {
this.embeddingModel = new AllMiniLmL6V2EmbeddingModel();
}
public void loadBundledStore(String resourcePath) {
try (InputStream is = getClass().getClassLoader().getResourceAsStream(resourcePath)) {
public void loadBundledStore(final String resourcePath) {
try (final InputStream is = getClass().getClassLoader().getResourceAsStream(resourcePath)) {
if (is != null) {
String json = new String(is.readAllBytes(), StandardCharsets.UTF_8);
final String json = new String(is.readAllBytes(), StandardCharsets.UTF_8);
this.embeddingStore = InMemoryEmbeddingStore.fromJson(json);
log.info("Loaded bundled embedding store from {}", resourcePath);
}
} catch (IOException e) {
} catch (final IOException e) {
log.error("Failed to load bundled embedding store", e);
}
}
@@ -44,13 +44,13 @@ public class RagStoreManager {
}
}
public void saveStore(Path path) {
public void saveStore(final Path path) {
if (embeddingStore == null) return;
try {
String json = embeddingStore.serializeToJson();
final String json = embeddingStore.serializeToJson();
Files.writeString(path, json);
log.info("Saved embedding store to {}", path);
} catch (IOException e) {
} catch (final IOException e) {
log.error("Failed to save embedding store", e);
}
}

View File

@@ -9,7 +9,7 @@ import java.nio.file.StandardOpenOption;
public class TestDataGenerator {
public static void main(String[] args) {
Path root = Paths.get("test-data");
final Path root = Paths.get("test-data");
try {
// 1. Create Directories
@@ -86,13 +86,13 @@ public class TestDataGenerator {
System.out.println("✅ Test data generated successfully in: " + root.toAbsolutePath());
} catch (IOException e) {
} catch (final IOException e) {
System.err.println("Failed to generate test data: " + e.getMessage());
e.printStackTrace();
}
}
private static void createFile(Path path, String content) throws IOException {
private static void createFile(final Path path, final String content) throws IOException {
Files.writeString(path, content, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING);
System.out.println("Created: " + path);
}

View File

@@ -0,0 +1,63 @@
package de.shahondin1624.rag.cache;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.store.embedding.CosineSimilarity;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
 * Small LRU-bounded cache that maps a query embedding to a previously computed
 * search result, so that semantically similar queries can skip the vector search.
 * Thread-safe: lookups take the read lock, insertions the write lock.
 */
public class SemanticCache {

    private static final Logger log = LoggerFactory.getLogger(SemanticCache.class);

    private final int maxEntries;
    private final double similarityThreshold;
    // Bounded map; eldest entry is evicted once maxEntries is exceeded.
    // NOTE(review): lookups never call Map.get(), so the accessOrder=true flag never
    // actually promotes entries — eviction is effectively insertion-ordered (FIFO).
    private final Map<CachedQuery, String> cache;
    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

    /**
     * Creates a semantic cache.
     *
     * @param maxEntries          maximum number of cached query/result pairs before eviction
     * @param similarityThreshold minimum cosine similarity for a cached entry to count as a hit
     */
    public SemanticCache(final int maxEntries, final double similarityThreshold) {
        this.maxEntries = maxEntries;
        this.similarityThreshold = similarityThreshold;
        this.cache = new LinkedHashMap<>(maxEntries, 0.75f, true) {
            @Override
            protected boolean removeEldestEntry(final Map.Entry<CachedQuery, String> eldest) {
                return size() > SemanticCache.this.maxEntries;
            }
        };
    }

    /**
     * Tries to find a semantically similar query in the cache.
     * <p>
     * Scans every entry and returns the result of the <em>most</em> similar one at or
     * above the threshold. (The previous implementation returned the first entry above
     * the threshold, which could miss a better-matching cached result.)
     *
     * @param incomingEmbedding the vector of the new user query
     * @return the cached string result, or {@code null} if no entry is similar enough
     */
    public String get(final Embedding incomingEmbedding) {
        lock.readLock().lock();
        try {
            String bestResult = null;
            double bestSimilarity = similarityThreshold;
            // Linear scan of the cache (fast for <100 entries)
            for (final Map.Entry<CachedQuery, String> entry : cache.entrySet()) {
                final double similarity = CosineSimilarity.between(incomingEmbedding, entry.getKey().embedding());
                if (similarity >= bestSimilarity) {
                    bestSimilarity = similarity;
                    bestResult = entry.getValue();
                }
            }
            if (bestResult != null) {
                // SLF4J instead of System.out.println: consistent with the rest of the
                // codebase, and stdout may carry the MCP stdio protocol — TODO confirm.
                log.info("Semantic cache hit (similarity: {})", bestSimilarity);
            }
            return bestResult;
        } finally {
            lock.readLock().unlock();
        }
    }

    /**
     * Stores a query result keyed by the query's embedding.
     *
     * @param originalQueryText the raw query text (kept alongside the vector for debugging)
     * @param embedding         the query's embedding vector
     * @param result            the search result to cache
     */
    public void put(final String originalQueryText, final Embedding embedding, final String result) {
        lock.writeLock().lock();
        try {
            cache.put(new CachedQuery(originalQueryText, embedding), result);
        } finally {
            lock.writeLock().unlock();
        }
    }

    // Wrapper record pairing the query text with its embedding (room for timestamps later).
    private record CachedQuery(String text, Embedding embedding) {}
}

View File

@@ -56,15 +56,15 @@ public class RagEmbedTool extends McpValidatedTool {
}
@Override
public McpSchema.CallToolResult callValidated(McpSchema.CallToolRequest request, Map<String, Object> arguments) {
String content = (String) arguments.get("content");
public McpSchema.CallToolResult callValidated(final McpSchema.CallToolRequest request, final Map<String, Object> arguments) {
final String content = (String) arguments.get("content");
@SuppressWarnings("unchecked")
Map<String, Object> rawMetadata = (Map<String, Object>) arguments.get("metadata");
final Map<String, Object> rawMetadata = (Map<String, Object>) arguments.get("metadata");
Map<String, String> metadata = null;
if (rawMetadata != null) {
metadata = new HashMap<>();
for (Map.Entry<String, Object> entry : rawMetadata.entrySet()) {
for (final Map.Entry<String, Object> entry : rawMetadata.entrySet()) {
metadata.put(entry.getKey(), String.valueOf(entry.getValue()));
}
}
@@ -74,7 +74,7 @@ public class RagEmbedTool extends McpValidatedTool {
try {
LocalRagService.getInstance().indexDocuments(List.of(content), metadata);
return successResult("Information successfully embedded and indexed.");
} catch (Exception e) {
} catch (final Exception e) {
logger.error("Error during information embedding", e);
return error("Error during information embedding: " + e.getMessage());
}

View File

@@ -48,14 +48,14 @@ public class RagSearchTool extends McpValidatedTool {
}
@Override
public McpSchema.CallToolResult callValidated(McpSchema.CallToolRequest request, Map<String, Object> arguments) {
String query = (String) arguments.get("query");
public McpSchema.CallToolResult callValidated(final McpSchema.CallToolRequest request, final Map<String, Object> arguments) {
final String query = (String) arguments.get("query");
logger.debug("RagSearchTool called with query: {}", query);
try {
String result = LocalRagService.getInstance().search(query);
final String result = LocalRagService.getInstance().search(query);
return successResult(result);
} catch (Exception e) {
} catch (final Exception e) {
logger.error("Error during RAG search", e);
return error("Error during RAG search: " + e.getMessage());
}