SapBERT (Self-alignment Pretraining for BERT): A Usage Example
I have recently been experimenting with SapBERT for computing entity similarity, and noticed that the official repo does not include a usage example. I wrote one up below so you can put SapBERT to work directly. My environment:
torch 1.7.1+cu101
torchvision 0.11.3
transformers 4.16.2
Here is the code. The key thing to understand is that SapBERT just extracts embedding vectors; once you have them, you can index and search with an approximate nearest-neighbour tool such as faiss:
from transformers import AutoTokenizer, AutoModel
import numpy as np
import torch

tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")

# Encode the query string
query = "cardiopathy"
query_toks = tokenizer.batch_encode_plus([query],
                                         padding="max_length",
                                         max_length=25,
                                         truncation=True,
                                         return_tensors="pt")
print(query_toks)
with torch.no_grad():
    query_output = model(**query_toks)
query_cls_rep = query_output[0][:, 0, :]  # [CLS] token embedding
print(query_cls_rep)

# Encode the candidate entity names
all_names = ["Neoplasm of anterior aspect of epiglottis"]
toks = tokenizer.batch_encode_plus(all_names,
                                   padding="max_length",
                                   max_length=25,
                                   truncation=True,
                                   return_tensors="pt")
with torch.no_grad():
    output = model(**toks)
cls_rep = output[0][:, 0, :]
print(cls_rep)

# Nearest-neighbour lookup by pairwise distance;
# for large-scale search, switch to faiss
from scipy.spatial.distance import cdist
dist = cdist(query_cls_rep.cpu().numpy(), cls_rep.cpu().numpy())
nn_index = np.argmin(dist)
# print("predicted label:", snomed_sf_id_pairs_100k[nn_index])