参考文献生成项目,使用faiss实现
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

236 lines
6.8 KiB

import os
import random
# Pin work to GPU 0. Set here, before the ML imports below
# (sentence_transformers pulls in torch), so they see it at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from flask import Flask, jsonify
from flask import request
import numpy as np
import faiss
import json
import requests
import socket
from sentence_transformers import SentenceTransformer
# Label-id mappings and the discipline-type lookup table.
# ("lable" spelling is kept throughout — it matches the on-disk file names.)
with open("data/lable/id2lable.json", encoding="utf-8") as f:
    id2lable = json.loads(f.read())
with open("data/lable/lable2id.json", encoding="utf-8") as f:
    lable2id = json.loads(f.read())
# encoding added for consistency with the two opens above; the platform
# default encoding is not guaranteed to be UTF-8.
with open("data/discipline_types.json", encoding="utf-8") as f:
    lable_discipline_types = json.loads(f.read())
app = Flask(__name__)
# Allow non-ASCII (Chinese) characters to pass through jsonify unescaped.
app.config["JSON_AS_ASCII"] = False
d = 768  # embedding dimension
model = SentenceTransformer('Dmeta-embedding-zh')
def get_host_ip():
    """
    Return the LAN IP address of this machine.

    Connects a UDP socket "towards" a public address (no packet is actually
    sent for a UDP connect) and reads back the local address the OS chose
    for that route.

    Fix: the original created the socket outside the try block, so a failure
    in socket() itself would raise NameError on `s` inside finally. The
    context manager guarantees the socket is closed on every path.

    :return: local IP address string, e.g. "192.168.31.149"
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(('8.8.8.8', 80))
        return s.getsockname()[0]
# url = "http://{}:50003/roformer".format(str(get_host_ip()))
# Endpoint of the remote roformer subject-classification service.
# (Hard-coded host; the commented line above auto-detected the local IP.)
url = "http://{}:50003/roformer".format("192.168.31.149")
def dialog_line_parse(url, text):
    """
    POST *text* as JSON to the model service and return the parsed reply.

    :param url: model service endpoint
    :param text: JSON-serializable payload for the model
    :return: decoded JSON response on HTTP 200, otherwise an empty list
    """
    response = requests.post(
        url,
        json=text,
        timeout=1000
    )
    if response.status_code == 200:
        return response.json()
    # Best-effort contract kept: report the failure and return an empty list
    # instead of raising. Fix: restored the opening 【 bracket that the
    # original message dropped.
    print("【{}】 Failed to get a proper response from remote "
          "server. Status Code: {}. Response: {}"
          "".format(url, response.status_code, response.text))
    print(text)
    return []
def panduan_paper_lable(paper_lable_text):
    """
    Map a Chinese paper-source label to its GB/T 7714 document type code.

    期刊 (journal) -> "J"; 硕士 (master's thesis) and 博士 (doctoral
    thesis) -> "D". Fix: the original mapped 博士 to "J", which mislabels
    doctoral dissertations as journal articles in the generated citation.

    :param paper_lable_text: one of "硕士", "期刊", "博士"
    :return: single-letter type code used inside the "[X]" citation marker
    :raises KeyError: for any other label (unchanged behavior)
    """
    paper_lable = {
        "硕士": "D",
        "期刊": "J",
        "博士": "D",  # dissertations are [D], not [J]
    }
    return paper_lable[paper_lable_text]
def ulit_recall_paper(reference_list, nums):
    """
    Turn recalled paper rows into citation strings plus metadata dicts.

    Rows with a title already seen are skipped; the surviving entries are
    shuffled and truncated to *nums* items.

    :param reference_list: rows shaped
        [authors, title, topic, source, year, abstract, class label]
    :param nums: number of entries to keep (string or int)
    :return: list of dicts with author/title/.../reference fields
    """
    data_info = []
    seen_titles = []
    for row in reference_list:
        if row[1] in seen_titles:
            continue
        print("data_one", row)
        print("data_one[0]", row[0])
        # Authors: split on ";", drop empties, strip newlines, rejoin with ",".
        authors = ",".join(
            str(part).replace("\n", "").replace("\r", "")
            for part in row[0].split(";")
            if part != ""
        )
        # e.g. "基于旅游资源开发下的新农村景观营建研究[D]"
        titled = row[1] + f"[{panduan_paper_lable(row[6])}]"
        source_year = ",".join([row[3], str(row[4]) + "."])
        reference = ".".join([authors, titled, source_year])
        seen_titles.append(row[1])
        data_info.append({
            "author": row[0],
            "title": row[1],
            "special_topic": row[2],
            "qikan_name": row[3],
            "year": str(row[4]),
            "abstract": row[5],
            "classlable": row[6],
            "reference": reference,
        })
    print(seen_titles)
    print(nums)
    # Two shuffle calls kept to reproduce the original call sequence exactly.
    random.shuffle(data_info)
    random.shuffle(data_info)
    return data_info[:int(nums)]
def main(title, abstract, nums):
    """
    Recall reference papers for a title/abstract pair.

    Pipeline:
      1. Ask the remote roformer service for a subject label.
      2. Load that subject's journal/master/doctor embeddings and metadata.
      3. Search the faiss index for the 20 nearest papers to the query.
      4. Format, dedupe and sample the hits down to *nums* entries.

    :param title: paper title
    :param abstract: paper abstract (may be empty)
    :param nums: number of references requested
    :return: ("200", list of reference metadata dicts)
    """
    data = {
        "title": title,
        "abst_zh": abstract,
        "content": ""
    }
    # Remote classifier returns e.g. {"label_num": [117, 143]}; the first
    # label selects the discipline.
    result = dialog_line_parse(url, data)
    subject_pinyin = lable_discipline_types[id2lable[str(result['label_num'][0])]]
    # Embeddings for journal (qikan), master and doctoral papers.
    data_subject = np.load(f"data/prompt_qikan/{subject_pinyin}.npy")
    data_subject_1 = np.load(f"data/prompt_master/{subject_pinyin}.npy")
    data_subject_2 = np.load(f"data/prompt_doctor/{subject_pinyin}.npy")
    print("xb.shape", data_subject.shape)
    print("xb_1.shape", data_subject_1.shape)
    print("xb_2.shape", data_subject_2.shape)
    data_subject = np.concatenate((data_subject, data_subject_1, data_subject_2))
    print("data_subject.shape", data_subject.shape)
    index = faiss.read_index(f'data/prompt_qikan_master_doctor_ivf/{subject_pinyin}.ivf')
    # encoding="utf-8" added for consistency with the label files loaded at
    # module import time; the platform default encoding is not guaranteed.
    with open(f"data/data_info_qikan/{subject_pinyin}.json", encoding="utf-8") as f:
        data_info = json.loads(f.read())
    with open(f"data/data_info_master/{subject_pinyin}.json", encoding="utf-8") as f:
        data_info_1 = json.loads(f.read())
    with open(f"data/data_info_doctor/{subject_pinyin}.json", encoding="utf-8") as f:
        data_info_2 = json.loads(f.read())
    print(len(data_info))
    print(len(data_info_1))
    print(len(data_info_2))
    # Metadata concatenation order must match the embedding concatenation
    # above: faiss returns positional indices into data_subject.
    data_info = data_info + data_info_1 + data_info_2
    print(len(data_info))
    print(data_info[0])
    index.add(data_subject)
    # index.nprobe = 2  # default nprobe is 1, try a few more
    k = 20  # recall pool size before sampling down to nums
    prompt = "标题:“{}”,摘要:“{}".format(title, abstract)
    embs = model.encode([prompt], normalize_embeddings=True)
    D, I = index.search(embs, int(k))
    reference_list = [data_info[i] for i in I[0]]
    data_info = ulit_recall_paper(reference_list, nums)
    return "200", data_info
@app.route("/", methods=["POST"])
def handle_query():
    """Recall-references endpoint; reads form fields `title` and `nums`."""
    title = request.form.get("title")
    nums = request.form.get('nums')
    # Abstract is intentionally left blank: only the title drives recall.
    status_code, data_info_list = main(title, "", nums)
    if status_code == "200":
        payload = {
            "data_info": data_info_list,
            "probabilities": None,
            "status_code": 200
        }
    else:
        # NOTE(review): "resilt" looks like a typo of "result" but is kept
        # byte-for-byte — clients may already depend on this key.
        payload = {"resilt": "", "probabilities": None, "status_code": 400}
    return jsonify(payload)  # send the result back to the caller
if __name__ == "__main__":
    # Serve on all interfaces; threaded=True lets Flask handle concurrent
    # requests (note each request re-reads the faiss index inside main()).
    app.run(host="0.0.0.0", port=17003, threaded=True)