【RAG入门必备技能】Faiss框架使用与FaissRetriever实现
faiss是一个Facebook AI团队开源的库,全称为Facebook AI Similarity Search,该开源库针对高维空间中的海量数据(稠密向量),提供了高效且可靠的相似性聚类和检索方法,可支持十亿级别向量的搜索,是目前最为成熟的近似近邻搜索库官方资源地址https://github.com/facebookresearch/faiss。
Faiss介绍
faiss是一个Facebook AI团队开源的库,全称为Facebook AI Similarity Search,该开源库针对高维空间中的海量数据(稠密向量),提供了高效且可靠的相似性聚类和检索方法,可支持十亿级别向量的搜索,是目前最为成熟的近似近邻搜索库
官方资源地址https://github.com/facebookresearch/faiss
Faiss基础依赖
1)矩阵计算框架:Faiss与计算资源之间需要一个外部依赖框架,这个框架是一个矩阵计算框架,官方默认配置安装的是OpenBlas,另外也可以用Intel的MKL,相比于OpenBlas使用MKL作为框架进行编译可以提高一定的稳定性。
2)OpenMP:如果向量之间的相似性搜索是逐条进行的那计算效率会非常低,而Faiss内部实现使用了OpenMP,可以以batch的形式来进行搜素,实现计算效率的提升。
Faiss工作数据流
在使用Faiss进行query向量的相似性搜索之前,需要将原始的向量集构建封装成一个索引文件(index file)并缓存在内存中,提供实时的查询计算。在第一次构建索引文件的时候,需要经过Train和Add两个过程。后续如果有新的向量需要被添加到索引文件的话还可以有一个Add操作从而实现增量build索引。
Faiss的核心
Faiss本质上是一个向量(矢量)数据库。进行搜索时,基础是原始向量数据库,基本单位是单个向量,默认输入一个向量x,返回和x最相似的k个向量。其中的核心就是索引(index对象),Index继承了一组向量库,作用是对原始向量集进行预处理和封装,一般操作包括train和add,可以建成一个索引对象缓存在计算机内存中。所有向量在建立前需要明确向量的维度d,大多数的索引还需要训练阶段来分析向量的分布(除了IndexFlatL2)。当索引被建立就可以进行后续的search操作了。
Train:
目的:生成原向量中心点,残差(向量中心点的差值)向量中心点,部分预计算的距离
流程:
1)把原始向量分成M个子空间,针对每个子空间训练中心点(如果每个子空间的中心点为n,则pq可表达n的M次方个中心点)。
2)查找向量对应的中心点
3)向量减去对应的中心点生成残差向量
4)针对残差向量生成二级量化器。
Search:
Search操作时索引的重要部分,search方法涉及实际的相似度计算,返回的检索结果包括两个矩阵,分别为xq中元素与近邻的距离大小和近邻向量的索引序号。
使用方法
Faiss是为稠密向量提供高效相似度搜索的框架(Facebook AI Research),选择索引方式是faiss的核心内容,faiss 三个最常用的索引是:IndexFlatL2, IndexIVFFlat,IndexIVFPQ。
-
IndexFlatL2/ IndexFlatIP为最基础的精确查找。
-
IndexIVFFlat称为倒排文件索引,是使用K-means建立聚类中心,通过查询最近的聚类中心,比较聚类中的所有向量得到相似的向量,是一种加速搜索方法的索引。
- IndexIVFPQ是一种减少内存的索引方式,IndexFlatL2和IndexIVFFlat都会全量存储所有的向量在内存中,面对大数据量,faiss提供一种基于Product Quantizer(乘积量化)的压缩算法编码向量到指定字节数来减少内存占用。但这种情况下,存储的向量是压缩过的,所以查询的距离也是近似的。
下面以为代码的方式讲解Faiss使用
构建句子向量表示模型
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import faiss
import numpy as np
import os
from tqdm import tqdm
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
在下面的代码片段中,我们建立了一个SemanticEmbedding类,它使用预先训练的 MPNet 模型将文本编码为向量。这里我选用了sbert的多语言模型paraphrase-multilingual-mpnet-base-v2
class SemanticEmbedding:
def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name)
# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(self, model_output, attention_mask):
token_embeddings = model_output[0] # First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
def get_embedding(self, sentences):
# Tokenize sentences
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
model_output = self.model(**encoded_input)
# Perform pooling
sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
# Normalize embeddings
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
return sentence_embeddings.detach().numpy()
调用方法如下:
model = SemanticEmbedding(r'I:\pretrained_models\bert\english\paraphrase-multilingual-mpnet-base-v2')
a = model.get_embedding("我喜欢打篮球")
print(a)
print(a.shape)
构建 Faiss 索引
Faiss 根据用户所需的功能提供多种不同索引选项。在下面代码中我们使用 IndexFlatIP,因为它的距离机制是内积,对于规范化嵌入而言,它与余弦相似度相同,值越大越相似。
在下面的代码中,我们创建了一个FaissIdx类,该类使用嵌入向量大小(在本例中为 768)初始化我们的索引,并使用计数器跟踪文档的 ID。请注意,Faiss 索引当前为空 — 它不包含任何文档向量。
然后向该类添加了两个方法来添加和搜索文档。为了实现这些方法,使用 Faiss 的 API 来获取文档的嵌入并搜索查询向量。请注意,搜索 API self.index.search 也接受参数 k 作为输入,它定义要返回的文档向量数量。在本例中,我们告诉该方法返回前三个文档向量,使用我们定义的 doc_map 数据结构来返回一个人类可读的文档。
class FaissIdx:
def __init__(self, model, dim=768):
self.index = faiss.IndexFlatIP(dim)
# Maintaining the document data
self.doc_map = dict()
self.model = model
self.ctr = 0
def add_doc(self, document_text):
self.index.add(self.model.get_embedding(document_text))
self.doc_map[self.ctr] = document_text # store the original document text
self.ctr += 1
def search_doc(self, query, k=3):
D, I = self.index.search(self.model.get_embedding(query), k)
return [{self.doc_map[idx]: score} for idx, score in zip(I[0], D[0]) if idx in self.doc_map]
测试索引
index = FaissIdx(model)
index.add_doc("笔记本电脑")
index.add_doc("医生的办公室")
result=index.search_doc("个人电脑")
print(result)
设置索引后,我们可以添加其他文档并对其进行搜索。在下面的代码中,我们添加文档“笔记本电脑”和“医生办公室”,然后搜索“PC 电脑”。请注意,“笔记本电脑”的相似度很高,而“医生办公室”的相似度很低,这是有道理的。请记住,余弦相似度的范围从 -1 到 1,其中 1 表示完全相似。
[{'笔记本电脑': 0.9483323}, {'医生的办公室': 0.39533424}]
构建语料索引
# 加载测试文档
data=pd.read_json('../../data/zh_refine.json', lines=True)[:50]
print(data)
print(data.columns)
for documents in tqdm(data['positive'],total=len(data)):
for document in documents:
index.add_doc(document)
for documents in tqdm(data['negative'],total=len(data)):
for document in documents:
index.add_doc(document)
result=index.search_doc("2022年特斯拉交付量")
print(result)
检索结果如下:
FaissRetriever实现
欢迎大家star正在开发的Gomate工具:https://github.com/gomate-community/GoMate
完整代码实现
https://github.com/gomate-community/GoMate/blob/main/gomate/modules/retrieval/faiss_retriever.py
import json
import os
import random
from concurrent.futures import ProcessPoolExecutor
from typing import List, Any
import faiss
import numpy as np
import tiktoken
from tqdm import tqdm
from gomate.modules.retrieval.embedding import BaseEmbeddingModel, OpenAIEmbeddingModel
from gomate.modules.retrieval.embedding import SBertEmbeddingModel
from gomate.modules.retrieval.retrievers import BaseRetriever
from gomate.modules.retrieval.utils import split_text
class FaissRetrieverConfig:
def __init__(
self,
max_tokens=100,
max_context_tokens=3500,
use_top_k=True,
embedding_model=None,
question_embedding_model=None,
top_k=5,
tokenizer=None,
embedding_model_string=None,
index_path=None,
rebuild_index=True
):
if max_tokens < 1:
raise ValueError("max_tokens must be at least 1")
if top_k < 1:
raise ValueError("top_k must be at least 1")
if max_context_tokens is not None and max_context_tokens < 1:
raise ValueError("max_context_tokens must be at least 1 or None")
if embedding_model is not None and not isinstance(
embedding_model, BaseEmbeddingModel
):
raise ValueError(
"embedding_model must be an instance of BaseEmbeddingModel or None"
)
if question_embedding_model is not None and not isinstance(
question_embedding_model, BaseEmbeddingModel
):
raise ValueError(
"question_embedding_model must be an instance of BaseEmbeddingModel or None"
)
self.top_k = top_k
self.max_tokens = max_tokens
self.max_context_tokens = max_context_tokens
self.use_top_k = use_top_k
self.embedding_model = embedding_model or OpenAIEmbeddingModel()
self.question_embedding_model = question_embedding_model or self.embedding_model
self.tokenizer = tokenizer
self.embedding_model_string = embedding_model_string or "OpenAI"
self.index_path = index_path
self.rebuild_index=rebuild_index
def log_config(self):
config_summary = """
FaissRetrieverConfig:
Max Tokens: {max_tokens}
Max Context Tokens: {max_context_tokens}
Use Top K: {use_top_k}
Embedding Model: {embedding_model}
Question Embedding Model: {question_embedding_model}
Top K: {top_k}
Tokenizer: {tokenizer}
Embedding Model String: {embedding_model_string}
Index Path: {index_path}
Rebuild Index Path: {rebuild_index}
""".format(
max_tokens=self.max_tokens,
max_context_tokens=self.max_context_tokens,
use_top_k=self.use_top_k,
embedding_model=self.embedding_model,
question_embedding_model=self.question_embedding_model,
top_k=self.top_k,
tokenizer=self.tokenizer,
embedding_model_string=self.embedding_model_string,
index_path=self.index_path,
rebuild_index=self.rebuild_index
)
return config_summary
class FaissRetriever(BaseRetriever):
"""
FaissRetriever is a class that retrieves similar context chunks for a given query using Faiss.
encoders_type is 'same' if the question and context encoder is the same,
otherwise, encoders_type is 'different'.
"""
def __init__(self, config):
self.embedding_model = config.embedding_model
self.question_embedding_model = config.question_embedding_model
self.index = None
self.context_chunks = []
self.max_tokens = config.max_tokens
self.max_context_tokens = config.max_context_tokens
self.use_top_k = config.use_top_k
self.tokenizer = config.tokenizer
self.top_k = config.top_k
self.embedding_model_string = config.embedding_model_string
self.index_path = config.index_path
self.rebuild_index=config.rebuild_index
# Load the index from the specified path if it is not None
if not self.rebuild_index:
if self.index_path and os.path.exists(self.index_path):
self.load_index(self.index_path)
else:
os.remove(self.index_path)
def load_index(self, index_path):
"""
Loads a Faiss index from a specified path.
:param index_path: Path to the Faiss index file.
"""
if os.path.exists(index_path):
self.index = faiss.read_index(index_path)
print("Index loaded successfully.")
else:
print("Index path does not exist.")
def encode_document(self, doc_text):
"""
Builds the index from a given text.
:param doc_text: A string containing the document text.
"""
# Split the text into context chunks
context_chunks = np.array(
split_text(doc_text, self.tokenizer, self.max_tokens)
)
# Collect embeddings using a for loop
embeddings = []
for context_chunk in context_chunks:
embedding = self.embedding_model.create_embedding(context_chunk)
embeddings.append(embedding)
embeddings = np.array(embeddings, dtype=np.float32)
return embeddings,context_chunks.tolist()
def build_from_texts(self, documents):
"""
Processes multiple documents in batches, builds the index, and saves it to disk.
:param documents: List of document texts to process.
:param save_path: Path to save the index file.
:param batch_size: Number of documents to process in each batch.
"""
self.all_embeddings = []
self.context_chunks=[]
for i in tqdm(range(0, len(documents))):
doc_embeddings,context_chunks = self.encode_document(documents[i])
self.all_embeddings.append(doc_embeddings)
self.context_chunks.extend(context_chunks)
# Initialize the index only once
if self.index is None and self.all_embeddings:
self.index = faiss.IndexFlatIP(self.all_embeddings[0].shape[1])
self.all_embeddings = np.vstack(self.all_embeddings)
print(self.all_embeddings.shape)
print(len(self.context_chunks))
self.index.add(self.all_embeddings)
# Save the index to disk
faiss.write_index(self.index, self.index_path)
def sanity_check(self, num_samples=4):
"""
Perform a sanity check by recomputing embeddings of a few randomly-selected chunks.
:param num_samples: The number of samples to test.
"""
indices = random.sample(range(len(self.context_chunks)), num_samples)
for i in indices:
original_embedding = self.all_embeddings[i]
recomputed_embedding = self.embedding_model.create_embedding(
self.context_chunks[i]
)
assert np.allclose(
original_embedding, recomputed_embedding
), f"Embeddings do not match for index {i}!"
print(f"Sanity check passed for {num_samples} random samples.")
def retrieve(self, query: str) -> list[Any]:
"""
Retrieves the k most similar context chunks for a given query.
:param query: A string containing the query.
:param k: An integer representing the number of similar context chunks to retrieve.
:return: A string containing the retrieved context chunks.
"""
query_embedding = np.array(
[
np.array(
self.question_embedding_model.create_embedding(query),
dtype=np.float32,
).squeeze()
]
)
context = []
if self.use_top_k:
distances, indices = self.index.search(query_embedding, self.top_k)
print(distances,indices)
print(distances[0][2],indices)
for i in range(self.top_k):
context.append({'text':self.context_chunks[indices[0][i]],'score':distances[0][i]})
else:
range_ = int(self.max_context_tokens / self.max_tokens)
_, indices = self.index.search(query_embedding, range_)
total_tokens = 0
for i in range(range_):
tokens = len(self.tokenizer.encode(self.context_chunks[indices[0][i]]))
context.append(self.context_chunks[indices[0][i]])
if total_tokens + tokens > self.max_context_tokens:
break
total_tokens += tokens
return context
if __name__ == '__main__':
from transformers import AutoTokenizer
embedding_model_path = "/home/test/pretrained_models/bge-large-zh-v1.5"
embedding_model = SBertEmbeddingModel(embedding_model_path)
tokenizer = AutoTokenizer.from_pretrained(embedding_model_path)
retriever_config = FaissRetrieverConfig(
max_tokens=100,
max_context_tokens=3500,
use_top_k=True,
embedding_model=embedding_model,
top_k=5,
tokenizer=tokenizer,
embedding_model_string="bge-large-zh-v1.5",
index_path="faiss_index.bin",
rebuild_index=True
)
faiss_retriever=FaissRetriever(config=retriever_config)
documents=[]
with open('/home/test/codes/GoMate/data/zh_refine.json','r',encoding="utf-8") as f:
for line in f.readlines():
data=json.loads(line)
documents.extend(data['positive'])
documents.extend(data['negative'])
print(len(documents))
faiss_retriever.build_from_texts(documents[:200])
contexts=faiss_retriever.retrieve("2022年冬奥会开幕式总导演是谁")
print(contexts)
参考资料
更多推荐
所有评论(0)