You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
234 lines
7.0 KiB
234 lines
7.0 KiB
import os
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
from flask import Flask, jsonify
|
|
from flask import request
|
|
import numpy as np
|
|
import faiss
|
|
import json
|
|
import requests
|
|
import socket
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
# Label lookup tables loaded once at import time.
# All three files are read as UTF-8 (the JSON standard encoding); previously
# discipline_types.json used the platform default encoding, which breaks on
# non-UTF-8 locales because its keys are the (Chinese) label names.

# Keyed by str(label id) -- main() looks it up with str(result['label_num'][0]).
with open("data/lable/id2lable.json", encoding="utf-8") as f:
    id2lable = json.load(f)

# Presumably the inverse of id2lable (label name -> id); not used in the visible code.
with open("data/lable/lable2id.json", encoding="utf-8") as f:
    lable2id = json.load(f)

# Label name -> subject key; main() uses the value as the file stem of the
# per-subject data files under data/.
with open("data/discipline_types.json", encoding="utf-8") as f:
    lable_discipline_types = json.load(f)
|
|
|
|
|
|
app = Flask(__name__)

# Emit non-ASCII (Chinese) text in JSON responses verbatim rather than as
# \uXXXX escapes. The config key only works on Flask < 2.3; Flask >= 2.2
# exposes the same switch on its JSON provider, and 2.3 removed the key.
app.config["JSON_AS_ASCII"] = False
if hasattr(app, "json"):
    app.json.ensure_ascii = False

d = 768  # embedding dimension; not referenced in the visible code

# Sentence-embedding model used by main() to encode the title/abstract query.
model = SentenceTransformer('Dmeta-embedding-zh')
|
|
|
|
def get_host_ip():
    """Return the local IPv4 address this machine uses for outbound traffic.

    connect() on a UDP socket performs no handshake; it only makes the OS pick
    the local address that would route to 8.8.8.8, which getsockname() then
    reports.

    :return: the local IP address as a string
    :raises OSError: if the socket cannot be created or no route exists
    """
    # The with-block closes the socket even when connect() raises, and never
    # touches an unassigned name if socket() itself fails.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(('8.8.8.8', 80))
        return s.getsockname()[0]
|
|
|
|
# Endpoint of the remote classifier that main() queries for the paper's label.
# The host defaults to the original LAN address and can be overridden with the
# ROFORMER_HOST environment variable (e.g. get_host_ip() for this machine).
url = "http://{}:50003/roformer".format(os.environ.get("ROFORMER_HOST", "192.168.31.149"))
|
|
|
|
def dialog_line_parse(url, text):
    """POST a payload to the model service and return its decoded JSON reply.

    :param url: model service URL
    :param text: JSON-serialisable payload sent as the request body
    :return: the parsed JSON body on HTTP 200; otherwise an empty list, after
        printing the URL, status code, response text and the payload
    """
    resp = requests.post(url, json=text, timeout=1000)  # timeout in seconds

    if resp.status_code != 200:
        failure_message = (
            "【{}】 Failed to get a proper response from remote "
            "server. Status Code: {}. Response: {}"
        ).format(url, resp.status_code, resp.text)
        print(failure_message)
        print(text)
        return []

    return resp.json()
|
|
|
|
|
|
def panduan_paper_lable(paper_lable_text):
    """Map a paper-type name to the bracketed type code used in citations.

    :param paper_lable_text: "硕士" (master's thesis), "博士" (doctoral
        dissertation) or "期刊" (journal article)
    :return: "D" for theses/dissertations, "J" for journal articles
    :raises KeyError: for any other paper type
    """
    paper_lable = {
        "硕士": "D",
        "期刊": "J",
        "博士": "D",
    }
    return paper_lable[paper_lable_text]


def _format_reference(record):
    """Render one metadata record as a citation, e.g. "赵璐.题名[D].西安建筑科技大学,2014."."""
    type_code = panduan_paper_lable(record[6])
    # Authors are ';'-separated and may carry stray CR/LF characters; names
    # that are empty after cleaning are dropped so no ",," appears.
    cleaned = (name.replace("\n", "").replace("\r", "") for name in record[0].split(";"))
    authors = ",".join(name for name in cleaned if name)
    citation = f"{authors}.{record[1]}[{type_code}].{record[3]},{record[4]}"
    if type_code == "J":
        # Journal entries append "(<record[8]>):<record[7]>" -- presumably
        # issue and pages; confirm against the data_info producer.
        citation += f",({record[8]}):{record[7]}"
    return citation + "."


def ulit_recall_paper(reference_list, nums):
    """Format retrieved paper records as distinct citation strings.

    :param reference_list: metadata records in retrieval (similarity) order.
        Fields used: 0 authors (';'-separated), 1 title, 3 source
        (journal / institution), 4 year, 6 paper type (see
        panduan_paper_lable), and for journal articles 7 and 8.
    :param nums: maximum number of citations to return (int or numeric string)
    :return: up to ``nums`` de-duplicated citations, in first-seen order
    :raises KeyError: if a record has an unknown paper type
    """
    citations = [_format_reference(record) for record in reference_list]
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # most similar papers survive the truncation below (a set would shuffle).
    unique = list(dict.fromkeys(citations))
    return unique[:int(nums)]
|
|
|
|
|
|
def _predict_subject(title, abstract):
    """Ask the remote classifier for the paper's label and map it to a subject key.

    :return: the subject key (file stem under data/), or None when the
        classifier call failed or returned no labels
    """
    payload = {"title": title, "abst_zh": abstract, "content": ""}
    try:
        result = dialog_line_parse(url, payload)
    except requests.RequestException as err:
        print("classifier request failed:", err)
        return None
    # dialog_line_parse returns [] on a non-200 reply; expected success shape
    # is {"label_num": [<label id>, ...]}.
    if not isinstance(result, dict) or not result.get("label_num"):
        return None
    return lable_discipline_types[id2lable[str(result["label_num"][0])]]


def _load_subject_corpus(subject_pinyin):
    """Load the FAISS index and paper metadata for one subject.

    Journal ("qikan"), master and doctor corpora are concatenated in the same
    order for the vectors and for the metadata.

    :return: (index, data_info) -- the index with every vector added, and the
        combined metadata list
    """
    vectors = np.concatenate([
        np.load(f"data/prompt_qikan/{subject_pinyin}.npy"),
        np.load(f"data/prompt_master/{subject_pinyin}.npy"),
        np.load(f"data/prompt_doctor/{subject_pinyin}.npy"),
    ])

    data_info = []
    for folder in ("data_info_qikan_1", "data_info_master", "data_info_doctor"):
        with open(f"data/{folder}/{subject_pinyin}.json", encoding="utf-8") as f:
            data_info += json.load(f)

    # NOTE(review): assumes the .ivf file holds a trained but empty index, so
    # FAISS ids equal row positions in data_info -- confirm.
    index = faiss.read_index(f'data/prompt_qikan_master_doctor_ivf/{subject_pinyin}.ivf')
    index.add(vectors)
    return index, data_info


def main(title, abstract, nums):
    """Recommend references for a paper from its title and abstract.

    Classifies the paper's subject remotely, embeds the title/abstract, and
    retrieves the nearest papers of that subject from the FAISS index.

    :param title: paper title
    :param abstract: paper abstract
    :param nums: maximum number of references to return (int or numeric string)
    :return: (status, references, abstracts) -- status is "200" on success or
        "400" when the classifier gave no usable answer (then both lists are empty)
    """
    subject_pinyin = _predict_subject(title, abstract)
    if subject_pinyin is None:
        return "400", [], []

    index, data_info = _load_subject_corpus(subject_pinyin)

    k = 50  # candidates retrieved before de-duplication and truncation to nums
    prompt = "标题:“{}”,摘要:“{}”".format(title, abstract)
    embs = model.encode([prompt], normalize_embeddings=True)
    _distances, ids = index.search(embs, k)

    # FAISS pads with -1 when fewer than k neighbours are found; indexing
    # data_info with -1 would silently return the last record.
    hits = [data_info[i] for i in ids[0] if i >= 0]
    abstract_list = [record[5] for record in hits]

    return "200", ulit_recall_paper(hits, nums), abstract_list
|
|
|
|
|
|
@app.route("/", methods=["POST"])
def handle_query():
    """Recommend numbered references for the paper described in the POST form.

    Form fields: ``title``, ``abstract`` and ``nums`` (maximum number of
    references). The JSON reply has the keys ``resilt`` (sic -- kept as-is
    since clients presumably read this spelling), ``probabilities`` (always
    None) and ``status_code`` (200 on success, 400 otherwise).
    """
    failure = {"resilt": "", "probabilities": None, "status_code": 400}

    title = request.form.get("title")
    abstract = request.form.get("abstract")
    nums = request.form.get("nums")

    # nums is eventually passed through int(); reject a missing or non-numeric
    # value with the 400 payload before doing any expensive work.
    try:
        int(nums)
    except (TypeError, ValueError):
        return jsonify(failure)

    status_code, reference, _abstracts = main(title, abstract, nums)
    if status_code != "200":
        return jsonify(failure)

    numbered = [f"[{n}]{text}" for n, text in enumerate(reference, start=1)]
    return jsonify({"resilt": numbered, "probabilities": None, "status_code": 200})
|
|
|
|
if __name__ == "__main__":
    # Development server: all interfaces, port 17001, one thread per request.
    app.run(threaded=True, port=17001, host="0.0.0.0")
|