最近在研究SapBERT来计算实体的相似度,发现官方的repo没有给使用示例,我仿照写了一下使用示例,方便直接把SapBERT用起来,我的环境是:

torch                   1.7.1+cu101
torchvision             0.11.3
transformers            4.16.2

下面是使用代码,知道SapBERT是抽取向量的就行了,然后就可以用一些类似faiss的近似向量检索工具进行检索了:

from transformers import AutoTokenizer, AutoModel
import numpy as np

tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")

model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")

query = "cardiopathy"
query_toks = tokenizer.batch_encode_plus([query], 
                                       padding="max_length", 
                                       max_length=25, 
                                       truncation=True,
                                       return_tensors="pt")
print(query_toks)
query_output = model(**query_toks)
query_cls_rep = query_output[0][:,0,:]
print(query_cls_rep)

all_names = ['Neoplasm of anterior aspect of epiglottis']

toks = tokenizer.batch_encode_plus(all_names, 
                                       padding="max_length", 
                                       max_length=25, 
                                       truncation=True,
                                       return_tensors="pt")

output = model(**toks)
cls_rep = output[0][:,0,:]
print(cls_rep)

# for large-scale search, should switch to faiss
from scipy.spatial.distance import cdist

dist = cdist(query_cls_rep.cpu().detach().numpy(), cls_rep.cpu().detach().numpy())
nn_index = np.argmin(dist)
# print ("predicted label:", snomed_sf_id_pairs_100k[nn_index])

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐