背景

在构建你自己的RAG(Retrieval Augmented Generation)系统时,你需要从外部数据源检索数据并用于你的LLM应用。这个时候Retriever就负责根据用户的查询(query)检索出一系列相关的文档。通常来说,检索出的文档会通过预置的prompt模板传递给大模型,最终大模型会根据整合后的prompt给出一个合适的答案。

Langchain提供的接口

BaseRetriever

要实现你自己的Retriver,你需要继承BaseRetriever类并实现_get_relevant_documents函数。如果你需要异步的实现,还可以实现_aget_relevant_documents函数。函数内部的具体逻辑由你自己决定:本地缓存、传统搜索引擎、向量数据库等都是备选项。

RunnableLambda / RunnableGenerator

你也可以通过实现RunnableLambda或RunnableGenerator来实现你自己的Retriver。不过Langchain官方还是比较建议用BaseRetriever。这是对于一些可观测基础设施,它们可能会对继承了BaseRetriever的类有特殊的处理逻辑,而对Runnable没有。
此外,在一些Langchain API中,BaseRetriever和RunnableLambda的行为会有略微的不同。比如在astream_events API中 ,start event会是 on_retriever_start 而不是 on_chain_start。

例子

这里我使用一个基于ChromaDb的向量数据库Retriver来演示
部署Chroma服务端

简单包装一下Chroma的client

class MyVectorDBConnector:
    def __init__(self,
                 embedding_fn: Callable = OpenAIEmbeddings().embed_documents,
                 collection_name: str = "test",
                 client=None
                 ):
        # client.reset()
        # 创建一个 collection
        if client is None:
            client = chromadb.HttpClient(host=os.environ.get("CHROMA_SERVER_IP"),
                                         port=8899,
                                         )

        self.client = client
        self.collection = self.client.get_or_create_collection(name=collection_name)
        self.embedding_fn = embedding_fn

    def search(self, query, top_n=3):
        """检索向量数据库"""
        results = self.collection.query(
            query_embeddings=self.embedding_fn(texts=[query]),
            n_results=top_n
        )
        return results

    def add_documents(self, documents):
        """向 collection 中添加文档与向量"""
        embeddings = self.embedding_fn(texts=documents)
        self.collection.add(
            embeddings=embeddings,  # 每个文档的向量
            documents=documents,  # 文档的原文
            ids=[f"id{i}" for i in range(len(documents))]  # 每个文档的 id
        )

    def get_all_documents(self):
        results = self.collection.get()
        return results

    def delete_all_documents(self):
        self.client.delete_collection(name=self.collection.name)

实现Retriever

from typing import List
from chromadb import QueryResult
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from Retriever.MyVectorDbConnector import MyVectorDBConnector


class MyVectorStoreRetriever(BaseRetriever):
    """
    基于向量数据库的 Retriever 实现
    """

    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
        """Retriever 的同步实现"""

        vector_db = MyVectorDBConnector()
        query_result: QueryResult = vector_db.search(query)
        if query_result:
            docs = query_result["documents"]
            if docs:
                return [Document(page_content=doc[0]) for doc in docs]
        return []

测试一下

if __name__ == "__main__":
    documents = [
        "科学家发现减少塑料污染的环保新材料",
        "亚洲股市今日普遍上涨,投资者信心增强",
        "中法签署新协议,促进双边艺术和教育合作",
        "本地初创公司推出智能应用程序,助力老年人生活更便捷",
        "国家足球队在国际友谊赛中取得压倒性胜利,球迷欢欣鼓舞",
    ]
    vector_db = MyVectorDBConnector()
    # 给向量数据库增加知识
    vector_db.add_documents(documents)
    
    my_retriever = MyVectorStoreRetriever()
    # 查询
    result = my_retriever.get_relevant_documents("体育新闻")
    print(result)

结果

[Document(page_content='国家足球队在国际友谊赛中取得压倒性胜利,球迷欢欣鼓舞')]

注意:BaseRetriever本身就是一个Langchain的Runnable,所以它可以直接invoke

invoke_result = my_retriever.invoke("金融信息")
    print(invoke_result)

结果

[Document(page_content='亚洲股市今日普遍上涨,投资者信心增强')]
Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐