Friends, SimBERT is a good model for similar-sentence retrieval, but large-scale retrieval calls for fast search, and that is where vector search libraries such as Milvus come in. Below we walk through applying SimBERT with Milvus using real code.
import numpy as np
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
import tensorflow as tf
from openapi_server.models.sentence_schema import SentenceSchema
from openapi_server.models.QaVecSchema import QaVecSchema
import connexion
from mysql_tool.connection import DBHelper
from config.loadconfig import get_logger
from milvus import Milvus, IndexType, MetricType, Status
import random
from bert4keras.snippets import sequence_padding
from apscheduler.schedulers.background import BackgroundScheduler
import datetime
import os
logger = get_logger(__name__)
global graph
graph = tf.get_default_graph()
sess = keras.backend.get_session()
# absolute path of the directory two levels above the current working directory
upper2path = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
# BERT configuration (paths to the SimBERT checkpoint)
config_path = "/Users/Downloads/data/model/chinese_simbert_L-6_H-384_A-12/bert_config.json"
checkpoint_path = "/Users/Downloads/data/model/chinese_simbert_L-6_H-384_A-12/bert_model.ckpt"
dict_path = "/Users/Downloads/data/model/chinese_simbert_L-6_H-384_A-12/vocab.txt"
# build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)
# build and load the SimBERT model
bert = build_transformer_model(
    config_path,
    checkpoint_path,
    with_pool='linear',
    application='unilm',
    return_keras_model=False,
)
# take the pooled output as the sentence encoder
encoder = keras.models.Model(bert.model.inputs, bert.model.outputs[0])
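As a quick sanity check, the encoder can be run on a single sentence; the example text below is made up, and the 384-dimensional output follows from the L-6_H-384_A-12 checkpoint and matches the collection dimension used later:
# quick sanity check: encode one made-up sentence and confirm a 384-dim vector
with sess.as_default():
    with graph.as_default():
        token_ids, segment_ids = tokenizer.encode("如何申请退款")
        vec = encoder.predict([[token_ids], [segment_ids]])[0]
        print(vec.shape)  # expected: (384,)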
Inserting vectors into Milvus:
def qa2vecs():
    # rebuild the target collection and load QA records (id, text) from MySQL
    collection_reconstruct()
    data = qa_query()
    milvus, collection_name = MilvusHelper().connection()
    param = {
        'collection_name': collection_name,
        'dimension': 384,
        'index_file_size': 256,  # optional
        'metric_type': MetricType.IP  # optional
    }
    milvus.create_collection(param)
    vecs = []
    ids = []
    progress_idx = 0
    with sess.as_default():
        with graph.as_default():
            for record in data:
                progress_idx += 1
                # encode each text into a 384-dim sentence vector
                token_ids, segment_ids = tokenizer.encode(record["text"])
                vec = encoder.predict([[token_ids], [segment_ids]])[0]
                vecs.append(vec)
                ids.append(record["id"])
                # flush to Milvus every 5000 records and once more at the end
                if (len(ids) % 5000 == 0 or progress_idx == len(data)) and len(ids) > 0:
                    logger.info("data sync :{:.2f}%".format(progress_idx * 100.0 / len(data)))
                    milvus.insert(collection_name=collection_name, records=vecs_normalize(vecs), ids=ids, params=param)
                    vecs = []
                    ids = []
    milvus.close()
    return progress_idx
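Note that qa2vecs calls vecs_normalize, which is not included in the snippet above. Since the collection is created with MetricType.IP, a reasonable guess is plain L2 normalization so that inner-product search behaves like cosine similarity; the helper below is an assumed sketch, not the original implementation:
# assumed sketch of vecs_normalize: L2-normalize vectors so that
# inner product (MetricType.IP) is equivalent to cosine similarity
def vecs_normalize(vecs):
    arr = np.asarray(vecs, dtype=np.float32)
    norms = np.linalg.norm(arr, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # avoid division by zero for all-zero vectors
    return (arr / norms).tolist()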
During the insert above, each record's id is stored in Milvus alongside its vector (an id -> vector mapping), while the id -> text mapping lives in MySQL. A search therefore returns the matching ids (and distances), and those ids are then used to look up the original texts in MySQL.
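To complete that retrieval path, a search step would encode the query, normalize it the same way as at insert time, ask Milvus for the top-k nearest ids, and resolve those ids to text in MySQL. The sketch below assumes the Milvus 1.x client used in this post (milvus.search returning a status plus a top-k result set) and a hypothetical fetch_texts_by_ids helper for the MySQL side:
def search_similar(query_text, top_k=10):
    # encode the query with the same encoder used at insert time
    milvus, collection_name = MilvusHelper().connection()
    with sess.as_default():
        with graph.as_default():
            token_ids, segment_ids = tokenizer.encode(query_text)
            query_vec = encoder.predict([[token_ids], [segment_ids]])[0]
    # ask Milvus for the top-k nearest ids under inner-product similarity
    status, results = milvus.search(
        collection_name=collection_name,
        query_records=vecs_normalize([query_vec]),
        top_k=top_k,
        params={'nprobe': 16},
    )
    milvus.close()
    if not status.OK():
        logger.error("milvus search failed: %s", status.message)
        return []
    hits = [(hit.id, hit.distance) for hit in results[0]]
    # fetch_texts_by_ids is a hypothetical MySQL helper (e.g. built on DBHelper)
    # that returns a dict mapping each returned id back to its original sentence
    id2text = fetch_texts_by_ids([hit_id for hit_id, _ in hits])
    return [(id2text.get(hit_id), distance) for hit_id, distance in hits]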