"""Flask service that returns formatted reference strings for a paper.

Reconstructed from ``generate_references_api_1.py`` as added by commit
0a88623 ("更新参考文献"), which replaces ``generate_references_api.py``.

Request flow (``POST /`` with form fields ``title``, ``abstract``, ``nums``):

1. ``title`` and ``abstract`` are posted to the remote classifier at ``url``;
   the first entry of its ``label_num`` answer is mapped through ``id2lable``
   and ``lable_discipline_types`` to a subject name.
2. The subject name selects the embedding matrices, the FAISS index and the
   record metadata (journal, master and doctor corpora) under ``data/``.
3. The query is embedded with ``model`` and the nearest records are formatted
   as reference strings, de-duplicated and truncated to ``nums`` entries.
"""
import logging
import os

# Set before the GPU-backed libraries below are imported.
# NOTE(review): presumably pins the process to GPU 0 -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import json
import socket

import faiss
import numpy as np
import requests
from flask import Flask, jsonify
from flask import request
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)

with open("data/lable/id2lable.json", encoding="utf-8") as f:
    id2lable = json.load(f)

with open("data/lable/lable2id.json", encoding="utf-8") as f:
    lable2id = json.load(f)

# NOTE(review): unlike the two label maps above, this file (and the per-subject
# JSON files read in _load_subject_resources) is opened with the locale's
# default encoding -- confirm the files' encoding before adding encoding=.
with open("data/discipline_types.json") as f:
    lable_discipline_types = json.load(f)


app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False

d = 768  # dimension
model = SentenceTransformer('Dmeta-embedding-zh')

# Number of neighbours requested from FAISS for every query, independent of
# the ``nums`` asked for by the client.
# NOTE(review): presumably over-fetches so that de-duplication still leaves
# ``nums`` entries; requests with nums > 50 get at most 50 -- confirm intent.
RECALL_TOP_K = 50

# Corpus tag stored in each record -> document type code used in the
# formatted reference. Master and doctoral theses are both dissertations and
# therefore share "D", as in the "[D]" example in ulit_recall_paper.
PAPER_TYPE_CODES = {
    "硕士": "D",
    "期刊": "J",
    "博士": "D",
}

# Corpora whose vectors and records are concatenated, in this order, for
# every subject. Vector row i must describe record i, so the order matters.
_CORPORA = ("qikan", "master", "doctor")


def get_host_ip():
    """Return the local address of a UDP socket connected to 8.8.8.8:80.

    :return: IP address string reported by ``getsockname()``
    """
    # The context manager closes the socket and, unlike try/finally around
    # the constructor, cannot fail on an unbound name if creation raises.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(('8.8.8.8', 80))
        return s.getsockname()[0]


# The host could also be derived with get_host_ip(); it is currently fixed.
url = "http://{}:50003/roformer".format("192.168.31.149")


def dialog_line_parse(url, text, timeout=1000):
    """Post ``text`` as JSON to the model service and return its answer.

    :param url: URL of the model service
    :param text: JSON-serialisable payload sent to the model
    :param timeout: request timeout in seconds
    :return: decoded JSON body on HTTP 200, otherwise an empty list
    """
    try:
        response = requests.post(url, json=text, timeout=timeout)
    except requests.RequestException as exc:
        # Connection errors and timeouts follow the same "return []"
        # contract as a non-200 answer instead of escaping to the caller.
        logger.error("【%s】 Request to remote server failed: %s", url, exc)
        return []

    if response.status_code == 200:
        return response.json()

    logger.error(
        "【%s】 Failed to get a proper response from remote "
        "server. Status Code: %s. Response: %s",
        url, response.status_code, response.text,
    )
    logger.error("payload: %s", text)
    return []


def panduan_paper_lable(paper_lable_text):
    """Return the document type code for a corpus tag.

    :param paper_lable_text: one of the keys of ``PAPER_TYPE_CODES``
    :return: "J" for journal articles, "D" for master/doctoral theses
    :raises KeyError: if the tag is not a known corpus tag
    """
    return PAPER_TYPE_CODES[paper_lable_text]


def _format_reference(record):
    """Format one record as ``authors.title[code].source,year.``."""
    # Record layout used below: [0] authors separated by ";", [1] title,
    # [3] source, [4] year, [6] corpus tag understood by panduan_paper_lable.
    authors = ",".join(
        str(name).replace("\n", "").replace("\r", "")
        for name in record[0].split(";")
        if name != ""
    )
    title = record[1] + f"[{panduan_paper_lable(record[6])}]"
    source = ",".join([record[3], str(record[4]) + "."])
    return ".".join([authors, title, source])


def ulit_recall_paper(reference_list, nums):
    """Format recalled records as reference strings.

    Example of a formatted entry (before numbering):
    "赵璐.基于旅游资源开发下的新农村景观营建研究[D].西安建筑科技大学,2014."

    :param reference_list: records in retrieval order
    :param nums: maximum number of references to return (int or numeric str)
    :return: list of distinct reference strings, at most ``nums`` long
    """
    formatted = [_format_reference(record) for record in reference_list]

    # dict.fromkeys removes duplicates but keeps first-seen order, so the
    # retrieval ranking survives and the output is stable between runs
    # (a set would return the entries in arbitrary order).
    unique = list(dict.fromkeys(formatted))
    logger.debug("%d distinct references, %s requested", len(unique), nums)

    # A negative slice end would drop entries from the tail; treat it as 0.
    limit = max(int(nums), 0)
    return unique[:limit]


def _load_subject_resources(subject_pinyin):
    """Load the FAISS index and the record metadata of one subject.

    :param subject_pinyin: subject name used in the file names under ``data/``
    :return: ``(index, records)``; the index has the corpus vectors added
    """
    vectors = np.concatenate([
        np.load(f"data/prompt_{corpus}/{subject_pinyin}.npy")
        for corpus in _CORPORA
    ])
    logger.debug("data_subject.shape %s", vectors.shape)

    index = faiss.read_index(
        f'data/prompt_qikan_master_doctor_ivf/{subject_pinyin}.ivf'
    )

    records = []
    for corpus in _CORPORA:
        with open(f"data/data_info_{corpus}/{subject_pinyin}.json") as f:
            records.extend(json.load(f))
    logger.debug("%d records loaded", len(records))

    # NOTE(review): this assumes the stored index holds no vectors yet;
    # if it already does, search ids will not line up with ``records``.
    # Everything here is also reloaded on every request -- consider caching.
    index.add(vectors)
    return index, records


def main(title, abstract, nums):
    """Retrieve references for a paper described by title and abstract.

    :param title: paper title
    :param abstract: paper abstract
    :param nums: maximum number of references to return
    :return: ``(status, references, abstracts)`` where ``status`` is "200" or
        "400", ``references`` is the output of ``ulit_recall_paper`` and
        ``abstracts`` holds field 5 of every retrieved record (before
        de-duplication and truncation)
    """
    data = {
        "title": title,
        "abst_zh": abstract,
        "content": ""
    }
    # Expected answer shape: {"label_num": [117, 143]}
    result = dialog_line_parse(url, data)

    try:
        label_id = result['label_num'][0]
        subject_pinyin = lable_discipline_types[id2lable[str(label_id)]]
    except (KeyError, IndexError, TypeError):
        # dialog_line_parse returns [] on failure; an unknown label id or
        # an empty label list cannot be mapped to a subject either.
        logger.error("could not determine subject from %r", result)
        return "400", [], []

    index, data_info = _load_subject_resources(subject_pinyin)

    prompt = "标题:“{}”,摘要:“{}”".format(title, abstract)
    embs = model.encode([prompt], normalize_embeddings=True)

    D, I = index.search(embs, RECALL_TOP_K)
    logger.debug("neighbour ids: %s", I)

    # FAISS fills missing neighbours with -1; indexing with it would
    # silently pick the last record, so those slots are skipped.
    hits = [data_info[i] for i in I[0] if i >= 0]
    abstract_list = [record[5] for record in hits]

    return "200", ulit_recall_paper(hits, nums), abstract_list


@app.route("/", methods=["POST"])
def handle_query():
    """Handle ``POST /`` and answer with numbered reference strings.

    Form fields: ``title``, ``abstract``, ``nums``. The JSON answer always
    has the keys ``resilt``, ``probabilities`` and ``status_code``; on
    success ``resilt`` is a list like ``["[1]...", "[2]..."]``, on failure
    it is an empty string and ``status_code`` is 400.
    """
    failure = {"resilt": "", "probabilities": None, "status_code": 400}

    # Missing fields become "" so the literal "None" is never embedded.
    title = request.form.get("title", "")
    abstract = request.form.get('abstract', "")

    try:
        nums = int(request.form.get('nums'))
    except (TypeError, ValueError):
        # nums missing or not an integer
        return jsonify(failure)

    status_code, reference_list, _abstract_list = main(title, abstract, nums)
    if status_code != "200":
        return jsonify(failure)

    logger.debug("references: %s", reference_list)
    reference = [
        f"[{position}]" + text
        for position, text in enumerate(reference_list, start=1)
    ]
    return jsonify({
        "resilt": reference,
        "probabilities": None,
        "status_code": 200
    })


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=17001, threaded=True)