Refactor and add semantic cache
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
package de.shahondin1624.rag;
|
||||
|
||||
import de.shahondin1624.rag.cache.SemanticCache;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
@@ -7,7 +8,6 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
@@ -24,6 +24,8 @@ public class LocalRagService {
|
||||
private final RagIndexer indexer;
|
||||
private final RagDocumentSplitter splitter;
|
||||
|
||||
private final SemanticCache semanticCache;
|
||||
|
||||
public static synchronized LocalRagService getInstance() {
|
||||
if (instance == null) {
|
||||
instance = new LocalRagService();
|
||||
@@ -35,6 +37,7 @@ public class LocalRagService {
|
||||
this.storeManager = new RagStoreManager();
|
||||
this.splitter = new RagDocumentSplitter();
|
||||
this.indexer = new RagIndexer(storeManager, splitter);
|
||||
this.semanticCache = new SemanticCache(50, 0.92);
|
||||
initialize();
|
||||
}
|
||||
|
||||
@@ -45,53 +48,64 @@ public class LocalRagService {
|
||||
log.info("LocalRagService initialized successfully.");
|
||||
}
|
||||
|
||||
public void indexDirectory(Path path) {
|
||||
public void indexDirectory(final Path path) {
|
||||
indexDirectory(path, null);
|
||||
}
|
||||
|
||||
public void indexDirectory(Path path, Map<String, String> metadata) {
|
||||
public void indexDirectory(final Path path, final Map<String, String> metadata) {
|
||||
indexer.indexDirectory(path, metadata);
|
||||
new RagFileWatcher(path, p -> indexer.indexFile(p, metadata), indexer::removeFile, indexer::isSupportedFile).start();
|
||||
}
|
||||
|
||||
public void indexFile(Path path) {
|
||||
public void indexFile(final Path path) {
|
||||
indexFile(path, null);
|
||||
}
|
||||
|
||||
public void indexFile(Path path, Map<String, String> metadata) {
|
||||
public void indexFile(final Path path, final Map<String, String> metadata) {
|
||||
indexer.indexFile(path, metadata);
|
||||
}
|
||||
|
||||
public void removeFile(Path path) {
|
||||
public void removeFile(final Path path) {
|
||||
indexer.removeFile(path);
|
||||
}
|
||||
|
||||
public void indexDocuments(List<String> docs) {
|
||||
public void indexDocuments(final List<String> docs) {
|
||||
indexDocuments(docs, null);
|
||||
}
|
||||
|
||||
public void indexDocuments(List<String> docs, Map<String, String> metadata) {
|
||||
public void indexDocuments(final List<String> docs, final Map<String, String> metadata) {
|
||||
log.info("Indexing {} documents...", docs.size());
|
||||
for (String docText : docs) {
|
||||
List<TextSegment> segments = splitter.split(docText, metadata);
|
||||
for (final String docText : docs) {
|
||||
final List<TextSegment> segments = splitter.split(docText, metadata);
|
||||
if (!segments.isEmpty()) {
|
||||
var embeddings = storeManager.getEmbeddingModel().embedAll(segments).content();
|
||||
final var embeddings = storeManager.getEmbeddingModel().embedAll(segments).content();
|
||||
storeManager.getEmbeddingStore().addAll(embeddings, segments);
|
||||
}
|
||||
}
|
||||
log.info("Indexing completed.");
|
||||
}
|
||||
|
||||
public String search(String query) {
|
||||
public String search(final String query) {
|
||||
log.info("Searching for: {}", query);
|
||||
var queryEmbedding = storeManager.getEmbeddingModel().embed(query).content();
|
||||
List<EmbeddingMatch<TextSegment>> matches = storeManager.getEmbeddingStore().findRelevant(queryEmbedding, 5);
|
||||
final var queryEmbedding = storeManager.getEmbeddingModel().embed(query).content();
|
||||
|
||||
String result = matches.stream()
|
||||
final String cachedResult = semanticCache.get(queryEmbedding);
|
||||
if (cachedResult != null) {
|
||||
log.info("⚡ Cache Hit! Returning result without vector search.");
|
||||
return cachedResult;
|
||||
}
|
||||
|
||||
final List<EmbeddingMatch<TextSegment>> matches = storeManager.getEmbeddingStore().findRelevant(queryEmbedding, 5);
|
||||
|
||||
final String result = matches.stream()
|
||||
.map(m -> m.embedded().text())
|
||||
.collect(Collectors.joining("\n\n---\n\n"));
|
||||
|
||||
log.info("Search completed with {} matches.", matches.size());
|
||||
|
||||
if (!result.isBlank()) {
|
||||
semanticCache.put(query, queryEmbedding, result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -99,7 +113,7 @@ public class LocalRagService {
|
||||
return storeManager.getEmbeddingStore();
|
||||
}
|
||||
|
||||
public void saveStore(Path path) {
|
||||
public void saveStore(final Path path) {
|
||||
storeManager.saveStore(path);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,13 +23,13 @@ public class RagDocumentSplitter {
|
||||
* @param metadata The metadata to associate with each segment.
|
||||
* @return A list of text segments.
|
||||
*/
|
||||
public List<TextSegment> split(String docText, Map<String, String> metadata) {
|
||||
List<TextSegment> segments = new ArrayList<>();
|
||||
Metadata lcMetadata = metadata != null ? Metadata.from(metadata) : new Metadata();
|
||||
public List<TextSegment> split(final String docText, final Map<String, String> metadata) {
|
||||
final List<TextSegment> segments = new ArrayList<>();
|
||||
final Metadata lcMetadata = metadata != null ? Metadata.from(metadata) : new Metadata();
|
||||
|
||||
// Try to split by DSL-like blocks or Markdown headers
|
||||
String[] blocks = docText.split("\n(?=## )|\n(?=\\w+\\s*\\{)");
|
||||
for (String block : blocks) {
|
||||
final String[] blocks = docText.split("\n(?=## )|\n(?=\\w+\\s*\\{)");
|
||||
for (final String block : blocks) {
|
||||
if (!block.isBlank()) {
|
||||
segments.add(TextSegment.from(block.trim(), lcMetadata));
|
||||
}
|
||||
@@ -37,9 +37,9 @@ public class RagDocumentSplitter {
|
||||
|
||||
if (segments.isEmpty()) {
|
||||
// Fallback to recursive splitter if no blocks found
|
||||
DocumentSplitter splitter = DocumentSplitters.recursive(500, 50);
|
||||
List<TextSegment> splitSegments = splitter.split(Document.from(docText));
|
||||
for (TextSegment segment : splitSegments) {
|
||||
final DocumentSplitter splitter = DocumentSplitters.recursive(500, 50);
|
||||
final List<TextSegment> splitSegments = splitter.split(Document.from(docText));
|
||||
for (final TextSegment segment : splitSegments) {
|
||||
segments.add(TextSegment.from(segment.text(), lcMetadata));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ public class RagFileWatcher {
|
||||
private final Consumer<Path> onFileDeleted;
|
||||
private final Predicate<Path> fileFilter;
|
||||
|
||||
public RagFileWatcher(Path dir, Consumer<Path> onFileChanged, Consumer<Path> onFileDeleted, Predicate<Path> fileFilter) {
|
||||
public RagFileWatcher(final Path dir, final Consumer<Path> onFileChanged, final Consumer<Path> onFileDeleted, final Predicate<Path> fileFilter) {
|
||||
this.dir = dir;
|
||||
this.onFileChanged = onFileChanged;
|
||||
this.onFileDeleted = onFileDeleted;
|
||||
@@ -28,17 +28,17 @@ public class RagFileWatcher {
|
||||
|
||||
public void start() {
|
||||
CompletableFuture.runAsync(() -> {
|
||||
try (WatchService watcher = FileSystems.getDefault().newWatchService()) {
|
||||
try (final WatchService watcher = FileSystems.getDefault().newWatchService()) {
|
||||
dir.register(watcher, StandardWatchEventKinds.ENTRY_CREATE,
|
||||
StandardWatchEventKinds.ENTRY_DELETE,
|
||||
StandardWatchEventKinds.ENTRY_MODIFY);
|
||||
log.info("Started watching directory for changes: {}", dir);
|
||||
while (!Thread.currentThread().isInterrupted()) {
|
||||
WatchKey key = watcher.take();
|
||||
for (WatchEvent<?> event : key.pollEvents()) {
|
||||
WatchEvent.Kind<?> kind = event.kind();
|
||||
Path fileName = (Path) event.context();
|
||||
Path filePath = dir.resolve(fileName);
|
||||
final WatchKey key = watcher.take();
|
||||
for (final WatchEvent<?> event : key.pollEvents()) {
|
||||
final WatchEvent.Kind<?> kind = event.kind();
|
||||
final Path fileName = (Path) event.context();
|
||||
final Path filePath = dir.resolve(fileName);
|
||||
|
||||
if (kind == StandardWatchEventKinds.ENTRY_CREATE || kind == StandardWatchEventKinds.ENTRY_MODIFY) {
|
||||
if (Files.isRegularFile(filePath) && fileFilter.test(filePath)) {
|
||||
@@ -52,7 +52,7 @@ public class RagFileWatcher {
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
} catch (final Exception e) {
|
||||
log.error("Error in file watcher", e);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -24,39 +24,39 @@ public class RagIndexer {
|
||||
private final RagDocumentSplitter splitter;
|
||||
private final Map<Path, List<String>> fileToIds = new ConcurrentHashMap<>();
|
||||
|
||||
public RagIndexer(RagStoreManager storeManager, RagDocumentSplitter splitter) {
|
||||
public RagIndexer(final RagStoreManager storeManager, final RagDocumentSplitter splitter) {
|
||||
this.storeManager = storeManager;
|
||||
this.splitter = splitter;
|
||||
}
|
||||
|
||||
public void indexDirectory(Path path) {
|
||||
public void indexDirectory(final Path path) {
|
||||
indexDirectory(path, null);
|
||||
}
|
||||
|
||||
public void indexDirectory(Path path, Map<String, String> metadata) {
|
||||
public void indexDirectory(final Path path, final Map<String, String> metadata) {
|
||||
log.info("Indexing directory: {}", path);
|
||||
try (Stream<Path> paths = Files.walk(path)) {
|
||||
try (final Stream<Path> paths = Files.walk(path)) {
|
||||
paths.filter(Files::isRegularFile)
|
||||
.filter(this::isSupportedFile)
|
||||
.forEach(p -> indexFile(p, metadata));
|
||||
} catch (IOException e) {
|
||||
} catch (final IOException e) {
|
||||
log.error("Failed to index directory: {}", path, e);
|
||||
}
|
||||
}
|
||||
|
||||
public void indexFile(Path path) {
|
||||
public void indexFile(final Path path) {
|
||||
indexFile(path, null);
|
||||
}
|
||||
|
||||
public void indexFile(Path path, Map<String, String> extraMetadata) {
|
||||
public void indexFile(final Path path, final Map<String, String> extraMetadata) {
|
||||
removeFile(path);
|
||||
try {
|
||||
String content = Files.readString(path);
|
||||
final String content = Files.readString(path);
|
||||
if (content.isBlank()) {
|
||||
return;
|
||||
}
|
||||
|
||||
Map<String, String> metadata = new HashMap<>();
|
||||
final Map<String, String> metadata = new HashMap<>();
|
||||
metadata.put("file_path", path.toString());
|
||||
if (extraMetadata != null) {
|
||||
metadata.putAll(extraMetadata);
|
||||
@@ -64,28 +64,28 @@ public class RagIndexer {
|
||||
|
||||
List<TextSegment> segments = splitter.split(content, metadata);
|
||||
if (!segments.isEmpty()) {
|
||||
List<Embedding> embeddings = storeManager.getEmbeddingModel().embedAll(segments).content();
|
||||
List<String> ids = storeManager.getEmbeddingStore().addAll(embeddings, segments);
|
||||
final List<Embedding> embeddings = storeManager.getEmbeddingModel().embedAll(segments).content();
|
||||
final List<String> ids = storeManager.getEmbeddingStore().addAll(embeddings, segments);
|
||||
fileToIds.put(path, ids);
|
||||
log.info("Indexed file: {} ({} segments)", path, segments.size());
|
||||
}
|
||||
} catch (IOException e) {
|
||||
} catch (final IOException e) {
|
||||
log.error("Failed to read file: {}", path, e);
|
||||
}
|
||||
}
|
||||
|
||||
public void removeFile(Path path) {
|
||||
List<String> ids = fileToIds.remove(path);
|
||||
public void removeFile(final Path path) {
|
||||
final List<String> ids = fileToIds.remove(path);
|
||||
if (ids != null) {
|
||||
for (String id : ids) {
|
||||
for (final String id : ids) {
|
||||
storeManager.getEmbeddingStore().remove(id);
|
||||
}
|
||||
log.info("Removed file from index: {}", path);
|
||||
}
|
||||
}
|
||||
|
||||
public boolean isSupportedFile(Path path) {
|
||||
String s = path.toString();
|
||||
public boolean isSupportedFile(final Path path) {
|
||||
final String s = path.toString();
|
||||
return s.endsWith(".md") || s.endsWith(".dsl") || s.endsWith(".txt");
|
||||
}
|
||||
|
||||
|
||||
@@ -26,14 +26,14 @@ public class RagStoreManager {
|
||||
this.embeddingModel = new AllMiniLmL6V2EmbeddingModel();
|
||||
}
|
||||
|
||||
public void loadBundledStore(String resourcePath) {
|
||||
try (InputStream is = getClass().getClassLoader().getResourceAsStream(resourcePath)) {
|
||||
public void loadBundledStore(final String resourcePath) {
|
||||
try (final InputStream is = getClass().getClassLoader().getResourceAsStream(resourcePath)) {
|
||||
if (is != null) {
|
||||
String json = new String(is.readAllBytes(), StandardCharsets.UTF_8);
|
||||
final String json = new String(is.readAllBytes(), StandardCharsets.UTF_8);
|
||||
this.embeddingStore = InMemoryEmbeddingStore.fromJson(json);
|
||||
log.info("Loaded bundled embedding store from {}", resourcePath);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
} catch (final IOException e) {
|
||||
log.error("Failed to load bundled embedding store", e);
|
||||
}
|
||||
}
|
||||
@@ -44,13 +44,13 @@ public class RagStoreManager {
|
||||
}
|
||||
}
|
||||
|
||||
public void saveStore(Path path) {
|
||||
public void saveStore(final Path path) {
|
||||
if (embeddingStore == null) return;
|
||||
try {
|
||||
String json = embeddingStore.serializeToJson();
|
||||
final String json = embeddingStore.serializeToJson();
|
||||
Files.writeString(path, json);
|
||||
log.info("Saved embedding store to {}", path);
|
||||
} catch (IOException e) {
|
||||
} catch (final IOException e) {
|
||||
log.error("Failed to save embedding store", e);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import java.nio.file.StandardOpenOption;
|
||||
public class TestDataGenerator {
|
||||
|
||||
public static void main(String[] args) {
|
||||
Path root = Paths.get("test-data");
|
||||
final Path root = Paths.get("test-data");
|
||||
|
||||
try {
|
||||
// 1. Create Directories
|
||||
@@ -86,13 +86,13 @@ public class TestDataGenerator {
|
||||
|
||||
System.out.println("✅ Test data generated successfully in: " + root.toAbsolutePath());
|
||||
|
||||
} catch (IOException e) {
|
||||
} catch (final IOException e) {
|
||||
System.err.println("Failed to generate test data: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
private static void createFile(Path path, String content) throws IOException {
|
||||
private static void createFile(final Path path, final String content) throws IOException {
|
||||
Files.writeString(path, content, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING);
|
||||
System.out.println("Created: " + path);
|
||||
}
|
||||
|
||||
63
src/main/java/de/shahondin1624/rag/cache/SemanticCache.java
vendored
Normal file
63
src/main/java/de/shahondin1624/rag/cache/SemanticCache.java
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
package de.shahondin1624.rag.cache;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.store.embedding.CosineSimilarity;

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * Small in-memory semantic cache: stores the embeddings of previously seen
 * queries together with their search results, and answers a new query whose
 * embedding has cosine similarity {@code >= similarityThreshold} with a cached
 * one — skipping the vector search entirely.
 *
 * <p>Capacity is bounded by {@code maxEntries} with LRU eviction (an
 * access-ordered {@link LinkedHashMap}). Thread-safe: lookups take the write
 * lock, because a cache hit promotes the matched entry, which mutates the
 * access-ordered map and must not run concurrently with other readers
 * (a {@link ReentrantReadWriteLock} cannot be upgraded from read to write).
 */
public class SemanticCache {

    private final int maxEntries;
    private final double similarityThreshold;

    private final Map<CachedQuery, String> cache;
    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

    /**
     * @param maxEntries          Maximum number of cached queries (LRU-evicted beyond this); must be &gt; 0.
     * @param similarityThreshold Minimum cosine similarity for a cached entry to count as a hit.
     * @throws IllegalArgumentException if {@code maxEntries} is not positive.
     */
    public SemanticCache(final int maxEntries, final double similarityThreshold) {
        if (maxEntries <= 0) {
            throw new IllegalArgumentException("maxEntries must be > 0, was " + maxEntries);
        }
        this.maxEntries = maxEntries;
        this.similarityThreshold = similarityThreshold;

        // accessOrder=true makes iteration order "least recently used first",
        // so removeEldestEntry drops the LRU entry once the cap is exceeded.
        this.cache = new LinkedHashMap<>(maxEntries, 0.75f, true) {
            @Override
            protected boolean removeEldestEntry(final Map.Entry<CachedQuery, String> eldest) {
                return size() > SemanticCache.this.maxEntries;
            }
        };
    }

    /**
     * Tries to find a semantically similar query in the cache.
     *
     * @param incomingEmbedding The vector of the new user query.
     * @return The cached string result, or null if no match found.
     */
    public String get(final Embedding incomingEmbedding) {
        // Write lock, not read lock: on a hit we refresh the entry's LRU
        // position via cache.get(...), which mutates the access-ordered map.
        // Without that promotion the accessOrder flag is dead weight and
        // eviction silently degrades to FIFO.
        lock.writeLock().lock();
        try {
            // Linear scan of the cache (fast for <100 entries).
            for (final Map.Entry<CachedQuery, String> entry : cache.entrySet()) {
                final double similarity = CosineSimilarity.between(incomingEmbedding, entry.getKey().embedding());
                if (similarity >= similarityThreshold) {
                    // Returning immediately after the promoting get() avoids a
                    // ConcurrentModificationException from the live iterator.
                    // Hit logging is the caller's concern (LocalRagService logs it).
                    return cache.get(entry.getKey());
                }
            }
            return null;
        } finally {
            lock.writeLock().unlock();
        }
    }

    /**
     * Stores a query/result pair, evicting the least recently used entry if
     * the cache is full.
     *
     * @param originalQueryText The raw query text (kept for debugging/metadata).
     * @param embedding         The query's embedding vector, used as the match key.
     * @param result            The search result to return on future hits.
     */
    public void put(final String originalQueryText, final Embedding embedding, final String result) {
        lock.writeLock().lock();
        try {
            cache.put(new CachedQuery(originalQueryText, embedding), result);
        } finally {
            lock.writeLock().unlock();
        }
    }

    // Wrapper record to keep the original text alongside the embedding
    // (e.g. for future metadata such as timestamps).
    private record CachedQuery(String text, Embedding embedding) {}
}
|
||||
@@ -56,15 +56,15 @@ public class RagEmbedTool extends McpValidatedTool {
|
||||
}
|
||||
|
||||
@Override
|
||||
public McpSchema.CallToolResult callValidated(McpSchema.CallToolRequest request, Map<String, Object> arguments) {
|
||||
String content = (String) arguments.get("content");
|
||||
public McpSchema.CallToolResult callValidated(final McpSchema.CallToolRequest request, final Map<String, Object> arguments) {
|
||||
final String content = (String) arguments.get("content");
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> rawMetadata = (Map<String, Object>) arguments.get("metadata");
|
||||
final Map<String, Object> rawMetadata = (Map<String, Object>) arguments.get("metadata");
|
||||
|
||||
Map<String, String> metadata = null;
|
||||
if (rawMetadata != null) {
|
||||
metadata = new HashMap<>();
|
||||
for (Map.Entry<String, Object> entry : rawMetadata.entrySet()) {
|
||||
for (final Map.Entry<String, Object> entry : rawMetadata.entrySet()) {
|
||||
metadata.put(entry.getKey(), String.valueOf(entry.getValue()));
|
||||
}
|
||||
}
|
||||
@@ -74,7 +74,7 @@ public class RagEmbedTool extends McpValidatedTool {
|
||||
try {
|
||||
LocalRagService.getInstance().indexDocuments(List.of(content), metadata);
|
||||
return successResult("Information successfully embedded and indexed.");
|
||||
} catch (Exception e) {
|
||||
} catch (final Exception e) {
|
||||
logger.error("Error during information embedding", e);
|
||||
return error("Error during information embedding: " + e.getMessage());
|
||||
}
|
||||
|
||||
@@ -48,14 +48,14 @@ public class RagSearchTool extends McpValidatedTool {
|
||||
}
|
||||
|
||||
@Override
|
||||
public McpSchema.CallToolResult callValidated(McpSchema.CallToolRequest request, Map<String, Object> arguments) {
|
||||
String query = (String) arguments.get("query");
|
||||
public McpSchema.CallToolResult callValidated(final McpSchema.CallToolRequest request, final Map<String, Object> arguments) {
|
||||
final String query = (String) arguments.get("query");
|
||||
logger.debug("RagSearchTool called with query: {}", query);
|
||||
|
||||
try {
|
||||
String result = LocalRagService.getInstance().search(query);
|
||||
final String result = LocalRagService.getInstance().search(query);
|
||||
return successResult(result);
|
||||
} catch (Exception e) {
|
||||
} catch (final Exception e) {
|
||||
logger.error("Error during RAG search", e);
|
||||
return error("Error during RAG search: " + e.getMessage());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user