以图搜图功能实现(ES/Milvus)
将非结构化数据→转为结构化→再完成搜索
·
思路方案
思路
需要将非结构化数据→转为结构化→再完成搜索。将非结构化数据,转化为结构化的多维向量,用这些向量标识实体和实体间的关系。再计算向量之间距离,通常情况下,距离越近、相似度越高,召回相似度最高的TOP结果,完成检索。
方案
给定一组查询图片和数据库图片。我们对数据库图片执行以图搜图操作,在image embeddings(将图片数据转换为固定大小的特征表示——矢量)上获取前k个最相似的数据库中的图片。
将采用以下两种方法执行以图搜图功能:
Milvus 向量数据库
Milvus 在非结构化数据处理中的应用非常强大。Milvus 向量相似度检索引擎可以兼容各种深度学习平台,搜索十亿向量仅毫秒响应。
ElasticSearch 向量数据库(resnet50模型)
方案一(Milvus1.0)
功能介绍
以图搜图,涉及两大功能:1、提取图像特征向量。2、相似向量检索。
通过计算特征向量来分析非结构化数据。使用ResNet-50进行特征提取,构建反向图像搜索系统。
环境搭建
yaml文件配置
version: 0.5
cluster:
enable: false
role: rw
general:
timezone: UTC+8
meta_uri: sqlite://:@:/
network:
bind.address: 0.0.0.0
bind.port: 19530
http.enable: true
http.port: 19121
storage:
path: /var/lib/milvus
auto_flush_interval: 1
wal:
enable: true
recovery_error_ignore: false
buffer_size: 256MB
path: /var/lib/milvus/wal
cache:
cache_size: 256MB
insert_buffer_size: 256MB
preload_collection:
gpu:
enable: false
cache_size: 256MB
gpu_search_threshold: 1000
search_devices:
- gpu0
build_index_devices:
- gpu0
fpga:
enable: false
search_devices:
- fpga0
logs:
level: debug
trace.enable: true
path: /var/lib/milvus/logs
max_log_file_size: 1073741824
log_rotate_num: 0
log_to_stdout: false
log_to_file: true
metric:
enable: false
address: 127.0.0.1
port: 9091
docker部署
docker run -d --name milvus_1 \
-p 19530:19530 \
-p 19121:19121 \
-v /root/milvus/db:/var/lib/milvus/db \
-v /root/milvus/conf:/var/lib/milvus/conf \
-v /root/milvus/logs:/var/lib/milvus/logs \
-v /root/milvus/wal:/var/lib/milvus/wal \
milvusdb/milvus:1.0.0-cpu-d030521-1ea92e
2)
docker run -d --name image_search \
-v /root/milvus/pic:/tmp/pic1 \
-p 35000:5000 \
-e "DATA_PATH=/tmp/images-data" \
-e "MILVUS_HOST=你的服务器ip地址" \
milvusbootcamp/pic-search-webserver:1.0
3)
docker run --name milvus_image_search_web -d --rm -p 8001:80 \
-e API_URL=http://你的服务器ip地址:35000 \
milvusbootcamp/pic-search-webclient:1.0
效果图
测试
原图进行验证搜索
截图进行验证搜索
不相关图片进行验证
升级方案(Milvus2.X)
问题
在调用模型时无法连接至hugging face 无法将图片转为向量
方案二(ElasticSearch + ResNet-50模型)
功能介绍
以图搜图,涉及两大功能:1、提取图像特征向量。2、相似向量检索。
第一个功能通过pytorch下载保存resnet50模型并在java端借助djl调用实现,第二个功能通过elasticsearch7.12.2的dense_vector、cosineSimilarity实现。
环境部署(通过编写pytorch模型并在java端借助djl调用实现)
提取图像特征下载模型到本地(resnet50模型)
import torch
import torch.nn as nn
import torchvision.models as models
class ImageFeatureExtractor(nn.Module):
def __init__(self):
super(ImageFeatureExtractor, self).__init__()
self.resnet = models.resnet50(pretrained=True)
#最终输出维度1024的向量,下文elastic search要设置dims为1024
self.resnet.fc = nn.Linear(2048, 1024)
def forward(self, x):
x = self.resnet(x)
return x
if __name__ == '__main__':
model = ImageFeatureExtractor()
model.eval()
#根据模型随便创建一个输入
input = torch.rand([1, 3, 224, 224])
output = model(input)
#以这种方式保存
script = torch.jit.trace(model, input)
script.save("model.pt")
保存好的model.pt文件放入java项目的resources中。
部署elasticsearch kibana
es版本:7.6.2
docker部署:
docker run -p 9200:9200 -p 9300:9300 \
--privileged=true --name es7.6.2 \
-e "discovery.type=single-node" \
-e ES_JAVA_OPTS="-Xms512m -Xmx1024m" \
-e "http.max_content_length=500mb" \
-v /root/mydata/plugins:/usr/share/elasticsearch/plugins \
-v /root/mydata/data:/usr/share/elasticsearch/data \
-v /root/mydata/logs:/usr/share/elasticsearch/logs \
-d elasticsearch:7.6.2
docker run -d \
--name kibana \
--restart=always \
-p 5601:5601 \
-v /data/kibana/config/kibana.yml:/usr/share/kibana/config/kibana.yml \
kibana:7.6.2
创建索引库
PUT /isi
{
"mappings": {
"properties": {
"vector": {
"type": "dense_vector",
"dims": 1024
},
"url" : {
"type" : "keyword"
},
"user_id": {
"type": "keyword"
}
}
}
}
相似向量上传、检索
创建调用resnet模型 转化格式
public class Test {
private static final String INDEX = "isi";
private static final int IMAGE_SIZE = 224;
private static Model model; //模型
private static Predictor<Image, float[]> predictor;
//predictor.predict(input)相当于python中model(input)
static {
try {
model = Model.newInstance("model");
//这里的model.pt是上面代码展示的那种方式保存的
model.load(Test.class.getClassLoader().getResourceAsStream("model.pt"));
Transform resize = new Resize(IMAGE_SIZE);
Transform toTensor = new ToTensor();
Transform normalize = new Normalize(new float[]{0.485f, 0.456f, 0.406f}, new float[]{0.229f, 0.224f, 0.225f});
//Translator处理输入Image转为tensor、输出转为float[]
Translator<Image, float[]> translator = new Translator<Image, float[]>() {
@Override
public NDList processInput(TranslatorContext ctx, Image input) throws Exception {
NDManager ndManager = ctx.getNDManager();
System.out.println("input: " + input.getWidth() + ", " + input.getHeight());
NDArray transform = normalize.transform(toTensor.transform(resize.transform(input.toNDArray(ndManager))));
System.out.println(transform.getShape());
NDList list = new NDList();
list.add(transform);
return list;
}
@Override
public float[] processOutput(TranslatorContext ctx, NDList ndList) throws Exception {
return ndList.get(0).toFloatArray();
}
};
predictor = new Predictor<>(model, translator, Device.cpu(), true);
} catch (Exception e) {
e.printStackTrace();
}
}
}
批量上传图片到es
public static void upload() throws Exception {
RestHighLevelClient client = new RestHighLevelClient(
RestClient.builder(new HttpHost("192.168.110.132", 9200, "http")));
//批量上传请求
File file = new File("E:\\javacode\\javaes\\src\\main\\resources\\test");
File[] files = file.listFiles();
if (files == null) return;
int batchSize = 1000;
for (int i = 0; i < files.length; i += batchSize) {
BulkRequest bulkRequest = new BulkRequest(INDEX);
for (int j = i; j < i + batchSize && j < files.length; j++) {
File listFile = files[j];
float[] vector = predictor.predict(ImageFactory.getInstance().fromInputStream(new FileInputStream(listFile)));
Map<String, Object> jsonMap = new HashMap<>();
jsonMap.put("url", listFile.getAbsolutePath());
jsonMap.put("vector", vector);
jsonMap.put("user_id", "user123");
IndexRequest request = new IndexRequest(INDEX).source(jsonMap, XContentType.JSON);
bulkRequest.add(request);
}
client.bulk(bulkRequest, RequestOptions.DEFAULT);
/*for (File listFile : file.listFiles()) {
float[] vector = predictor.predict(ImageFactory.getInstance().fromInputStream(Test2.class.getClassLoader().getResourceAsStream("test/" + listFile.getName())));
// 构建文档
Map<String, Object> jsonMap = new HashMap<>();
jsonMap.put("url", listFile.getAbsolutePath());
jsonMap.put("vector", vector);
jsonMap.put("user_id", "user123");
IndexRequest request = new IndexRequest(INDEX).source(jsonMap, XContentType.JSON);
bulkRequest.add(request);*/
}
client.close();
}
搜索(将图片转为向量与es文档库匹配)
public static List<SearchResult> search(InputStream input) throws Throwable {
float[] vector = predictor.predict(ImageFactory.getInstance().fromInputStream(input));
System.out.println(Arrays.toString(vector));
//展示k个结果
int k = 50;
// 连接Elasticsearch服务器
RestHighLevelClient client = new RestHighLevelClient(
RestClient.builder(new HttpHost("192.168.110.132", 9200, "http")));
SearchRequest searchRequest = new SearchRequest(INDEX);
Script script = new Script(
ScriptType.INLINE,
"painless",
"cosineSimilarity(params.queryVector, doc['vector'])",
Collections.singletonMap("queryVector", vector));
FunctionScoreQueryBuilder functionScoreQueryBuilder = QueryBuilders.functionScoreQuery(
QueryBuilders.matchAllQuery(),
ScoreFunctionBuilders.scriptFunction(script));
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(functionScoreQueryBuilder)
.fetchSource(null, "vector") //不返回vector字段,没用还耗时
.size(k);
searchRequest.source(searchSourceBuilder);
SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);
SearchHits hits = searchResponse.getHits();
List<SearchResult> list = new ArrayList<>();
for (SearchHit hit : hits) {
// 处理搜索结果
System.out.println(hit.toString());
SearchResult result = new SearchResult((String) hit.getSourceAsMap().get("url"), hit.getScore());
list.add(result);
}
client.close();
return list;
}
效果图
测试
原图进行验证搜索
截图进行验证搜索
不相关图片进行验证
更多推荐
已为社区贡献1条内容
所有评论(0)