通过Langchain创建你自己的Retriever(检索器)
构建Langchain中的自定义Retriever
背景
在构建你自己的RAG(Retrieval Augmented Generation)系统时,你需要从外部数据源检索数据并用于你的LLM应用。这个时候Retriever就负责根据用户的查询(query)检索出一系列相关的文档。通常来说,检索出的文档会通过预置的prompt模板传递给大模型,最终大模型会根据整合后的prompt给出一个合适的答案。
Langchain提供的接口
BaseRetriever
要实现你自己的Retriver,你需要继承BaseRetriever类并实现_get_relevant_documents函数。如果你需要异步的实现,还可以实现_aget_relevant_documents函数。函数内部的具体逻辑由你自己决定:本地缓存、传统搜索引擎、向量数据库等都是备选项。
RunnableLambda / RunnableGenerator
你也可以通过实现RunnableLambda或RunnableGenerator来实现你自己的Retriver。不过Langchain官方还是比较建议用BaseRetriever。这是对于一些可观测基础设施,它们可能会对继承了BaseRetriever的类有特殊的处理逻辑,而对Runnable没有。
此外,在一些Langchain API中,BaseRetriever和RunnableLambda的行为会有略微的不同。比如在astream_events API中 ,start event会是 on_retriever_start 而不是 on_chain_start。
例子
这里我使用一个基于ChromaDb的向量数据库Retriver来演示
部署Chroma服务端
简单包装一下Chroma的client
class MyVectorDBConnector:
def __init__(self,
embedding_fn: Callable = OpenAIEmbeddings().embed_documents,
collection_name: str = "test",
client=None
):
# client.reset()
# 创建一个 collection
if client is None:
client = chromadb.HttpClient(host=os.environ.get("CHROMA_SERVER_IP"),
port=8899,
)
self.client = client
self.collection = self.client.get_or_create_collection(name=collection_name)
self.embedding_fn = embedding_fn
def search(self, query, top_n=3):
"""检索向量数据库"""
results = self.collection.query(
query_embeddings=self.embedding_fn(texts=[query]),
n_results=top_n
)
return results
def add_documents(self, documents):
"""向 collection 中添加文档与向量"""
embeddings = self.embedding_fn(texts=documents)
self.collection.add(
embeddings=embeddings, # 每个文档的向量
documents=documents, # 文档的原文
ids=[f"id{i}" for i in range(len(documents))] # 每个文档的 id
)
def get_all_documents(self):
results = self.collection.get()
return results
def delete_all_documents(self):
self.client.delete_collection(name=self.collection.name)
实现Retriever
from typing import List
from chromadb import QueryResult
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from Retriever.MyVectorDbConnector import MyVectorDBConnector
class MyVectorStoreRetriever(BaseRetriever):
"""
基于向量数据库的 Retriever 实现
"""
def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
"""Retriever 的同步实现"""
vector_db = MyVectorDBConnector()
query_result: QueryResult = vector_db.search(query)
if query_result:
docs = query_result["documents"]
if docs:
return [Document(page_content=doc[0]) for doc in docs]
return []
测试一下
if __name__ == "__main__":
documents = [
"科学家发现减少塑料污染的环保新材料",
"亚洲股市今日普遍上涨,投资者信心增强",
"中法签署新协议,促进双边艺术和教育合作",
"本地初创公司推出智能应用程序,助力老年人生活更便捷",
"国家足球队在国际友谊赛中取得压倒性胜利,球迷欢欣鼓舞",
]
vector_db = MyVectorDBConnector()
# 给向量数据库增加知识
vector_db.add_documents(documents)
my_retriever = MyVectorStoreRetriever()
# 查询
result = my_retriever.get_relevant_documents("体育新闻")
print(result)
结果
[Document(page_content='国家足球队在国际友谊赛中取得压倒性胜利,球迷欢欣鼓舞')]
注意:BaseRetriever本身就是一个Langchain的Runnable,所以它可以直接invoke
invoke_result = my_retriever.invoke("金融信息")
print(invoke_result)
结果
[Document(page_content='亚洲股市今日普遍上涨,投资者信心增强')]
更多推荐
所有评论(0)