import os
import random

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import json
import socket

import faiss
import numpy as np
import requests
from flask import Flask, jsonify, request
from sentence_transformers import SentenceTransformer

# Label mappings and discipline types (the variable and file names keep the
# original "lable" spelling so the on-disk paths still match).
with open("data/lable/id2lable.json", encoding="utf-8") as f:
    id2lable = json.loads(f.read())

with open("data/lable/lable2id.json", encoding="utf-8") as f:
    lable2id = json.loads(f.read())

with open("data/discipline_types.json") as f:
    lable_discipline_types = json.loads(f.read())

app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False

d = 768  # embedding dimension
model = SentenceTransformer('Dmeta-embedding-zh')


def get_host_ip():
    """
    Look up the local IP address.
    :return: ip
    """
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        s.connect(('8.8.8.8', 80))
        ip = s.getsockname()[0]
    finally:
        s.close()
    return ip


# url = "http://{}:50003/roformer".format(str(get_host_ip()))
url = "http://{}:50003/roformer".format("192.168.31.149")


def dialog_line_parse(url, text):
    """
    Send the request data to the classification model and return its result.
    :param url: model url
    :param text: payload sent to the model
    :return: model response as JSON, or [] on failure
    """
    response = requests.post(
        url,
        json=text,
        timeout=1000
    )
    if response.status_code == 200:
        return response.json()
    else:
        print("【{}】 Failed to get a proper response from remote "
              "server. Status Code: {}. Response: {}"
              "".format(url, response.status_code, response.text))
        print(text)
        return []


def panduan_paper_lable(paper_lable_text):
    # Map the source type to its GB/T 7714 document-type code:
    # dissertations (master's and doctoral) are [D], journal articles are [J].
    paper_lable = {
        "硕士": "D",
        "期刊": "J",
        "博士": "D"
    }
    return paper_lable[paper_lable_text]


def ulit_recall_paper(reference_list, nums):
    """
    Deduplicate the recalled papers by title, build a formatted reference
    string for each one, and return a random sample of `nums` of them.
    :param reference_list: recalled records, each ordered as
        [author, title, topic, source, year, abstract, type]
    :param nums: number of papers to return
    :return: list of dicts with the paper metadata and formatted reference
    """
    # Example of a formatted reference:
    # "[1]赵璐.基于旅游资源开发下的新农村景观营建研究[D].西安建筑科技大学,2014."
    data_info = []
    data_title = []
    for data_one in reference_list:
        if data_one[1] not in data_title:
            print("data_one", data_one)
            print("data_one[0]", data_one[0])
            paper = ".".join([
                ",".join([str(i).replace("\n", "").replace("\r", "")
                          for i in data_one[0].split(";") if i != ""]),
                data_one[1] + f"[{panduan_paper_lable(data_one[6])}]",
                ",".join([
                    data_one[3],
                    str(data_one[4]) + "."
                ])
            ])
            data_title.append(data_one[1])
            data_info.append({
                "author": data_one[0],
                "title": data_one[1],
                "special_topic": data_one[2],
                "qikan_name": data_one[3],
                "year": str(data_one[4]),
                "abstract": data_one[5],
                "classlable": data_one[6],
                "reference": paper
            })

    print(data_title)
    print(nums)
    random.shuffle(data_info)
    data_info = data_info[:int(nums)]

    return data_info


def main(title, abstract, nums):
    data = {
        "title": title,
        "abst_zh": abstract,
        "content": ""
    }

    # The classifier returns e.g. {"label_num": [117, 143]}; the first label
    # decides which discipline's embeddings, index and metadata to load.
    result = dialog_line_parse(url, data)
    subject_pinyin = lable_discipline_types[id2lable[str(result['label_num'][0])]]

    # Precomputed embeddings for journal, master's and doctoral papers.
    data_subject = np.load(f"data/prompt_qikan/{subject_pinyin}.npy")
    data_subject_1 = np.load(f"data/prompt_master/{subject_pinyin}.npy")
    data_subject_2 = np.load(f"data/prompt_doctor/{subject_pinyin}.npy")
    print("xb.shape", data_subject.shape)
    print("xb_1.shape", data_subject_1.shape)
    print("xb_2.shape", data_subject_2.shape)
    data_subject = np.concatenate((data_subject, data_subject_1, data_subject_2))
    print("data_subject.shape", data_subject.shape)

    index = faiss.read_index(f'data/prompt_qikan_master_doctor_ivf/{subject_pinyin}.ivf')

    # Metadata lists are concatenated in the same order as the embeddings so
    # that FAISS result ids line up with positions in data_info.
    with open(f"data/data_info_qikan/{subject_pinyin}.json") as f:
        data_info = json.loads(f.read())
    with open(f"data/data_info_master/{subject_pinyin}.json") as f:
        data_info_1 = json.loads(f.read())
    with open(f"data/data_info_doctor/{subject_pinyin}.json") as f:
        data_info_2 = json.loads(f.read())

    print(len(data_info))
    print(len(data_info_1))
    print(len(data_info_2))
    data_info = data_info + data_info_1 + data_info_2
    print(len(data_info))
    print(data_info[0])

    index.add(data_subject)
    # index.nprobe = 2  # default nprobe is 1, try a few more
    k = 20
    prompt = "标题:“{}”,摘要:“{}”".format(title, abstract)
    embs = model.encode([prompt], normalize_embeddings=True)
    D, I = index.search(embs, int(k))
    reference_list = []
    for i in I[0]:
        reference_list.append(data_info[i])

    data_info = ulit_recall_paper(reference_list, nums)

    return "200", data_info


@app.route("/", methods=["POST"])
def handle_query():
    title = request.form.get("title")
    abstract = ""  # the abstract is currently not taken from the request
    nums = request.form.get('nums')

    status_code, data_info_list = main(title, abstract, nums)

    if status_code == "200":
        return_text = {
            "data_info": data_info_list,
            "probabilities": None,
            "status_code": 200
        }
    else:
        return_text = {"result": "", "probabilities": None, "status_code": 400}
    return jsonify(return_text)  # return the result


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=17003, threaded=True)
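

# Example request (a minimal sketch, assuming the service is running locally
# on the port configured in app.run() above; the title and nums values are
# made up for illustration):
#
#   import requests
#
#   resp = requests.post(
#       "http://127.0.0.1:17003/",
#       data={"title": "基于旅游资源开发下的新农村景观营建研究", "nums": "5"},
#   )
#   print(resp.json()["data_info"])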