大模型 · 2025-01-31 0

使用 LangChain4j 实现 RAG（检索增强生成）

1. pom.xml 依赖配置

<!-- Maven pom.xml fragment: both elements below go inside the project's <project> element. -->
<properties>
    <!-- Compile for Java 21; RagMain uses a text block ("""), which needs at least Java 15. -->
    <maven.compiler.source>21</maven.compiler.source>
    <maven.compiler.target>21</maven.compiler.target>
    <!-- The Java sources contain Chinese string literals, so the source encoding must be UTF-8. -->
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <!-- One version shared by every LangChain4j artifact below so they stay in sync. -->
    <langchain4j.version>1.7.1</langchain4j.version>
</properties>

<dependencies>
    <!-- LangChain4j Core: AiServices, document splitters and the in-memory embedding store used by RagAssistant -->
    <dependency>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j</artifactId>
        <version>${langchain4j.version}</version>
    </dependency>

    <!-- LangChain4j OpenAI: OpenAiChatModel / OpenAiEmbeddingModel. RagMain points them at
         OpenAI-compatible endpoints (DashScope compatible mode, or a local Ollama /v1 URL). -->
    <dependency>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j-open-ai</artifactId>
        <version>${langchain4j.version}</version>
    </dependency>
</dependencies>

2. Java 代码（RagAssistant）

import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;

import java.util.List;

/**
 * Minimal retrieval-augmented generation (RAG) assistant built on LangChain4j.
 *
 * <p>{@link #ingestText(String)} splits text into small overlapping segments, embeds them with
 * the configured embedding model and keeps them in an {@link InMemoryEmbeddingStore}.
 * {@link #answer(String)} sends the question through an {@link AiServices}-generated
 * {@link RagService}, whose content retriever looks up the most similar stored segments and
 * supplies them to the chat model as context.
 *
 * <p>Stored segments live only in this instance's memory and are lost when it is discarded.
 */
public class RagAssistant {

    /** Maximum number of stored segments retrieved for each question. */
    private static final int MAX_RESULTS = 3;

    /**
     * Minimum relevance score a segment needs to be retrieved. 0.0 effectively disables score
     * filtering, so the top {@link #MAX_RESULTS} segments are always returned.
     */
    private static final double MIN_SCORE = 0.0;

    /** Maximum segment size, in characters, for the recursive document splitter. */
    private static final int MAX_SEGMENT_CHARS = 60;

    /** Maximum overlap, in characters, between consecutive segments. */
    private static final int MAX_OVERLAP_CHARS = 20;

    /**
     * Number of segments sent per embedding request. Batching avoids one HTTP round-trip per
     * segment, while a small size stays within providers that cap inputs per request.
     * NOTE(review): 10 is a conservative choice — confirm against the provider's documented limit.
     */
    private static final int EMBEDDING_BATCH_SIZE = 10;

    private final OpenAiChatModel chatModel;
    private final EmbeddingModel embeddingModel;
    private final EmbeddingStore<TextSegment> embeddingStore;

    private final RagService service;

    /**
     * Creates an assistant whose chat and embedding models both talk to the same
     * OpenAI-compatible endpoint.
     *
     * @param baseApiUrl         base URL of the OpenAI-compatible API
     *                           (e.g. {@code https://dashscope.aliyuncs.com/compatible-mode/v1})
     * @param chatModelName      chat model used to generate answers
     * @param embeddingModelName embedding model used both for ingestion and for retrieval
     * @param apiKey             API key sent to the endpoint
     */
    public RagAssistant(String baseApiUrl,
                        String chatModelName,
                        String embeddingModelName,
                        String apiKey) {
        this.chatModel = OpenAiChatModel.builder()
                .baseUrl(baseApiUrl)
                .modelName(chatModelName)
                .apiKey(apiKey)
                // Low temperature for less varied answers.
                .temperature(0.2)
                // Caps the length of each generated answer.
                .maxTokens(2000)
                .build();

        this.embeddingModel = OpenAiEmbeddingModel.builder()
                .baseUrl(baseApiUrl)
                .modelName(embeddingModelName)
                .apiKey(apiKey)
                .build();

        this.embeddingStore = new InMemoryEmbeddingStore<>();

        // The same embedding model must embed queries and stored segments so their vectors
        // are comparable.
        ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
                .embeddingModel(embeddingModel)
                .embeddingStore(embeddingStore)
                .maxResults(MAX_RESULTS)
                .minScore(MIN_SCORE)
                .build();

        this.service = AiServices.builder(RagService.class)
                .chatModel(chatModel)
                // The provider's argument (the chat memory id) is ignored: every conversation
                // gets the same system message.
                .systemMessageProvider(memoryId -> "你是一个知识检索助理。")
                .contentRetriever(contentRetriever)
                .build();
    }

    /**
     * Splits {@code text} into overlapping segments, embeds them and adds them to the store so
     * later questions can retrieve them.
     *
     * <p>Segments are embedded in batches of {@link #EMBEDDING_BATCH_SIZE}, one request per
     * batch, rather than one request per segment.
     *
     * @param text knowledge text to ingest (presumably must be non-blank for
     *             {@code Document.from} — TODO confirm)
     */
    public void ingestText(String text) {
        Document doc = Document.from(text);
        List<TextSegment> segments = DocumentSplitters
                .recursive(MAX_SEGMENT_CHARS, MAX_OVERLAP_CHARS)
                .split(doc);

        for (int from = 0; from < segments.size(); from += EMBEDDING_BATCH_SIZE) {
            int to = Math.min(from + EMBEDDING_BATCH_SIZE, segments.size());
            List<TextSegment> batch = segments.subList(from, to);
            List<Embedding> embeddings = embeddingModel.embedAll(batch).content();
            // Embeddings come back in input order, so they pair index-by-index with the batch.
            embeddingStore.addAll(embeddings, batch);
        }
    }

    /**
     * Answers {@code question} using the chat model, augmented with the most relevant ingested
     * segments.
     *
     * @param question user question
     * @return the model's answer text
     */
    public String answer(String question) {
        return service.ask(question);
    }

    /** AI service contract implemented by LangChain4j's {@link AiServices} at runtime. */
    public interface RagService {
        /**
         * Sends the prompt, together with retrieved context, to the chat model.
         *
         * @param userPrompt user question
         * @return the model's reply text
         */
        String ask(String userPrompt);
    }
}

3. 测试类（RagMain）

/**
 * Command-line demo for {@link RagAssistant}.
 *
 * <p>It ingests a short knowledge text and asks one question against DashScope's
 * OpenAI-compatible endpoint.
 *
 * <p>The API key is read from the {@code DASHSCOPE_API_KEY} environment variable instead of
 * being hard-coded in source, so it is never committed with the code.
 */
public class RagMain {

    /** Environment variable that must hold the DashScope API key. */
    private static final String API_KEY_ENV = "DASHSCOPE_API_KEY";

    /**
     * Runs the demo: builds the assistant, ingests sample text, asks a question and prints the
     * answer.
     *
     * @param args ignored
     * @throws IllegalStateException if {@code DASHSCOPE_API_KEY} is unset or blank
     */
    public static void main(String[] args) {
        String apiKey = System.getenv(API_KEY_ENV);
        if (apiKey == null || apiKey.isBlank()) {
            throw new IllegalStateException(
                    "Set the " + API_KEY_ENV + " environment variable to your DashScope API key");
        }

        // RAG over DashScope's OpenAI-compatible endpoint.
        RagAssistant rag = new RagAssistant(
                "https://dashscope.aliyuncs.com/compatible-mode/v1",
                "qwen-turbo",
                "text-embedding-v1",
                apiKey);
        // Alternative: a local Ollama server via its OpenAI-compatible /v1 API.
        // This configuration passes null as the API key.
        // RagAssistant rag = new RagAssistant(
        //         "http://localhost:11434/v1",
        //         "deepseek-r1:14b",
        //         "nomic-embed-text:latest",
        //         null);

        // Ingest a piece of knowledge text directly.
        rag.ingestText("""
                LangChain4j 是一个用于在 Java 中构建 LLM 应用的库,支持聊天、向量检索、工具调用等能力。
                RAG(检索增强生成)通过将问题与向量数据库中的相关文档匹配,把上下文提供给大模型,从而得到更准确的回答。
                """);

        String question = "什么是 RAG,它如何提升答案的准确性?";
        String response = rag.answer(question);
        System.out.println("RAG 答复:" + response);
    }
}