参考文献生成项目,使用 faiss 实现 — a reference (bibliography) recommendation service built on FAISS vector search and a sentence-embedding model.
import os

# Pin the process to GPU 0 *before* any CUDA-aware library gets imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import json
import socket

import faiss
import numpy as np
import requests
from flask import Flask, jsonify, request
from sentence_transformers import SentenceTransformer

# Label-id mappings and discipline lookup tables.
# NOTE: the "lable" spelling is kept because it matches the on-disk paths.
with open("data/lable/id2lable.json", encoding="utf-8") as f:
    id2lable = json.load(f)
with open("data/lable/lable2id.json", encoding="utf-8") as f:
    lable2id = json.load(f)
with open("data/discipline_types.json") as f:
    lable_discipline_types = json.load(f)

app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False  # allow Chinese text in JSON responses

d = 768  # embedding dimension
model = SentenceTransformer('Dmeta-embedding-zh')
def get_host_ip():
    """Return the LAN IP address of this machine.

    "Connects" a UDP socket to a public address (UDP connect sends no
    packets) and reads back the local address the OS chose for that route.

    :return: local IP address as a string
    """
    # ``with`` guarantees the socket is closed on every path; the original
    # could raise NameError in ``finally`` if socket() itself failed,
    # because ``s`` was never bound.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(('8.8.8.8', 80))
        return s.getsockname()[0]
# url = "http://{}:50003/roformer".format(str(get_host_ip()))
# Endpoint of the roformer discipline-classification service.
# NOTE(review): hard-coded LAN IP; presumably the commented get_host_ip()
# line above is meant for other deployments — confirm before changing.
url = "http://{}:50003/roformer".format("192.168.31.149")
def dialog_line_parse(url, text):
    """POST ``text`` as JSON to the model service at ``url``.

    :param url: model service endpoint
    :param text: payload dict sent to the model
    :return: parsed JSON response on HTTP 200, otherwise an empty list
    """
    response = requests.post(
        url,
        json=text,
        timeout=1000,
    )
    if response.status_code == 200:
        return response.json()
    # Fix: the original message dropped the opening 【 bracket ("{}】 ...").
    print("【{}】 Failed to get a proper response from remote "
          "server. Status Code: {}. Response: {}"
          "".format(url, response.status_code, response.text))
    print(text)
    return []
def panduan_paper_lable(paper_lable_text):
    """Map a Chinese paper-type name to its GB/T 7714 reference code.

    "期刊" (journal) -> "J"; "硕士"/"博士" (master/doctor thesis) -> "D".
    Raises KeyError for any other type name.
    """
    type_codes = {
        "硕士": "D",
        "期刊": "J",
        "博士": "D",
    }
    return type_codes[paper_lable_text]
def ulit_recall_paper(reference_list, nums):
    '''
    Format recalled paper records as GB/T 7714-style reference strings.

    e.g. "[1]赵璐.基于旅游资源开发下的新农村景观营建研究[D].西安建筑科技大学,2014."

    :param reference_list: rows shaped like
        [authors (";"-joined), title, _, source, year, abstract, type,
         pages, issue] — indices 2 and 5 are unused here.
        NOTE(review): column meanings inferred from usage; confirm upstream.
    :param nums: maximum number of references to return (str or int)
    :return: deduplicated list of formatted strings, at most int(nums) long
    '''
    data = []
    for data_one in reference_list:
        print("data_one", data_one)
        print("data_one[0]", data_one[0])
        # Authors: split on ";", strip newlines, re-join with ",".
        authors = ",".join(
            str(i).replace("\n", "").replace("\r", "")
            for i in data_one[0].split(";") if i != ""
        )
        lable = panduan_paper_lable(data_one[6])
        if lable == "J":
            # Journal: 作者.标题[J].期刊,年(期):页码.
            paper = ".".join([
                authors,
                data_one[1] + f"[{lable}]",
                ",".join([data_one[3], str(data_one[4])]),
            ]) + f"({data_one[8]})" + f":{data_one[7]}" + "."
        else:
            # Thesis: 作者.标题[D].学校,年.
            paper = ".".join([
                authors,
                data_one[1] + f"[{lable}]",
                ",".join([data_one[3], str(data_one[4]) + "."]),
            ])
        data.append(paper)
    # Fix: dict.fromkeys dedupes while PRESERVING the faiss relevance
    # order; the original list(set(...)) shuffled results arbitrarily.
    data = list(dict.fromkeys(data))
    print(len(data))
    if data:  # guard: the original crashed on an empty recall list
        print(data[0])
    print(nums)
    return data[:int(nums)]
def main(title, abstract, nums):
    """Classify the paper's discipline, then recall similar papers via faiss.

    :param title: paper title
    :param abstract: paper abstract
    :param nums: max number of references to return (converted with int())
    :return: ("200", formatted reference list, list of recalled abstracts)
    """
    data = {
        "title": title,
        "abst_zh": abstract,
        "content": ""
    }
    # Expected response shape from the roformer service:
    # {
    #     "label_num": [
    #         117,
    #         143
    #     ]
    # }
    result = dialog_line_parse(url, data)
    # print(result['label_num'][0])
    # print(id2lable[result['label_num'][0]])
    # Predicted label id -> discipline name -> pinyin key used in file paths.
    subject_pinyin = lable_discipline_types[id2lable[str(result['label_num'][0])]]
    # with open(f"data/prompt/{subject_pinyin}.npy") as :
    # zidonghua = np.load('data/prompt/{subject_pinyin}.npy')
    # Per-discipline embedding matrices: journal, master thesis, doctor thesis.
    data_subject = np.load(f"data/prompt_qikan/{subject_pinyin}.npy")
    data_subject_1 = np.load(f"data/prompt_master/{subject_pinyin}.npy")
    data_subject_2 = np.load(f"data/prompt_doctor/{subject_pinyin}.npy")
    print("xb.shape", data_subject.shape)
    print("xb_1.shape", data_subject_1.shape)
    print("xb_2.shape", data_subject_2.shape)
    data_subject = np.concatenate((data_subject, data_subject_1, data_subject_2))
    print("data_subject.shape", data_subject.shape)
    # Pre-trained IVF index for this discipline; vectors are added below, so
    # the concatenation order above must match data_info's order below.
    index = faiss.read_index(f'data/prompt_qikan_master_doctor_ivf/{subject_pinyin}.ivf')
    with open(f"data/data_info_qikan_1/{subject_pinyin}.json") as f:
        data_info = json.loads(f.read())
    with open(f"data/data_info_master/{subject_pinyin}.json") as f:
        data_info_1 = json.loads(f.read())
    with open(f"data/data_info_doctor/{subject_pinyin}.json") as f:
        data_info_2 = json.loads(f.read())
    print(len(data_info))
    print(len(data_info_1))
    print(len(data_info_2))
    # Metadata rows concatenated in the SAME order as the embeddings above,
    # so search hit index i maps directly to data_info[i].
    data_info = data_info + data_info_1 + data_info_2
    print(len(data_info))
    print(data_info[0])
    index.add(data_subject)
    # index.nprobe = 2 # default nprobe is 1, try a few more
    # k = nums
    k = 50  # always recall 50 candidates; the final cut to `nums` happens later
    # NOTE(review): the closing ” after the abstract is missing from this
    # template — possibly intentional (matches training prompts); confirm.
    prompt = "标题:“{}”,摘要:“{}".format(title, abstract)
    embs = model.encode([prompt], normalize_embeddings=True)
    D, I = index.search(embs, int(k))
    print(I)
    reference_list = []
    abstract_list = []
    for i in I[0]:
        reference_list.append(data_info[i])
        abstract_list.append(data_info[i][5])  # column 5 holds the abstract
    return "200", ulit_recall_paper(reference_list, nums), abstract_list
@app.route("/", methods=["POST"])
def handle_query():
    """HTTP entry point: recommend references for a title + abstract.

    Form fields: title, abstract, nums (max number of references).
    Responds with {"resilt": [...], "probabilities": None, "status_code": ...}.
    NOTE: the "resilt" key (sic) is the public API contract — do not rename.
    """
    title = request.form.get("title")
    abstract = request.form.get('abstract')
    nums = request.form.get('nums')
    # content = ulit_request_file(file)
    status_code, reference, abstract_list = main(title, abstract, nums)
    if status_code != "400":
        print(reference)
        # Prefix each entry with its 1-based citation number: "[1]...".
        reference = [
            f"[{str(idx + 1)}]" + item for idx, item in enumerate(reference)
        ]
    if status_code == "200":
        return_text = {
            "resilt": reference,
            "probabilities": None,
            "status_code": 200
        }
    else:
        return_text = {"resilt": "", "probabilities": None, "status_code": 400}
    return jsonify(return_text)  # send the JSON response
if __name__ == "__main__":
    # threaded=True lets Flask handle concurrent requests on port 17001.
    app.run(host="0.0.0.0", port=17001, threaded=True)