From 62c34defd33ff5df2580af1cd9ed197420d56a81 Mon Sep 17 00:00:00 2001
From: "majiahui@haimaqingfan.com"
Date: Tue, 30 Jan 2024 10:29:06 +0800
Subject: [PATCH] =?UTF-8?q?=E7=94=9F=E6=88=90=E5=8F=82=E8=80=83=E6=96=87?=
 =?UTF-8?q?=E7=8C=AE=E7=AC=AC=E4=B8=80=E6=AC=A1=E6=8F=90=E4=BA=A4?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Reviewer note: the added files below were revised in place, so the blob ids on
the "index" lines still refer to the original contents.

 .idea/.gitignore    |   0
 README.md           |  25 ++++++
 accurate_check.py   | 189 ++++++++++++++++++++++++++++++++++++++++++
 flask_api.py        | 232 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 gunicorn_config.py  |  21 +++++
 run_api_gunicorn.sh |   1 +
 6 files changed, 468 insertions(+)
 create mode 100644 .idea/.gitignore
 create mode 100644 README.md
 create mode 100644 accurate_check.py
 create mode 100644 flask_api.py
 create mode 100644 gunicorn_config.py
 create mode 100644 run_api_gunicorn.sh

diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..e69de29
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..37d4fc7
--- /dev/null
+++ b/README.md
@@ -0,0 +1,25 @@
+## 安装环境
+
+```bash
+conda create -n llama_paper python=3.8
+```
+
+## 启动项目
+启动此项目前必须启动 vllm-main 项目
+
+```bash
+conda activate llama_paper
+# 结果处理进程:消费 Redis 队列 query_recall 并写回结果,需与 API 同时运行
+python accurate_check.py &
+bash run_api_gunicorn.sh
+```
+
+## 测试
+
+接口为 `POST /`(端口 17000),参数以表单形式提交:`title`、`abstract`、`nums`。
+
+```bash
+curl -X POST -F "title=论文标题" -F "abstract=论文摘要" -F "nums=10" http://127.0.0.1:17000/
+```
+
+返回"status_code"不出现 400 则调用成功
\ No newline at end of file
diff --git a/accurate_check.py b/accurate_check.py
new file mode 100644
index 0000000..5022785
--- /dev/null
+++ b/accurate_check.py
@@ -0,0 +1,189 @@
+import json
+import datetime
+import os
+import pymysql
+import re
+import requests
+from flask import Flask, jsonify
+from flask import request
+import uuid
+import time
+import redis
+from threading import Thread
+
+from clickhouse_driver import Client
+
+# NOTE(review): the passwords stay hard-coded as fallbacks so existing deployments
+# keep working; prefer supplying them through the environment.
+REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD", "zhicheng123*")
+CLICKHOUSE_PASSWORD = os.environ.get("CLICKHOUSE_PASSWORD", "zhicheng123*")
+
+pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=8, password=REDIS_PASSWORD)
+# NOTE(review): redis-py appears to ignore decode_responses when an explicit
+# connection_pool is given, which is why replies are decoded by hand below.
+redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
+
+db_key_query = 'query_check_task'
+db_key_querying = 'querying_check_task'
+db_key_queryset = 'queryset_check_task'
+db_key_query_recall = 'query_recall'
+
+# Directory that receives one JSON result file per request uuid.
+RESULT_LOG_DIR = "./new_data_logs"
+# Lifetime in seconds of the Redis key that maps a uuid to its result file path.
+RESULT_TTL_SECONDS = 86400
+
+
+def run_query(conn, sql, params):
+    """Execute ``sql`` with ``params`` on ``conn`` and return every row.
+
+    :param conn: DB-API style connection whose cursor is a context manager
+    :param sql: SQL text containing the driver's placeholders
+    :param params: values bound to the placeholders
+    :return: the value of ``cursor.fetchall()``
+    """
+    with conn.cursor() as cursor:
+        cursor.execute(sql, params)
+        return cursor.fetchall()
+
+
+class PureClient:
+    """Thin wrapper around a ClickHouse ``Client`` that returns DataFrames."""
+
+    def __init__(self, database='test_db'):
+        # Only the local (intranet) address has to be written here.
+        self.client = Client(host='192.168.31.74', port=9000, user='default',
+                             password=CLICKHOUSE_PASSWORD, database=database)
+
+    def run(self, sql, params=None):
+        """Run ``sql`` and return the result as a DataFrame.
+
+        :param sql: query text; may contain ``%(name)s`` placeholders
+        :param params: optional dict the driver substitutes into the placeholders
+        :return: the value of ``Client.query_dataframe``
+        """
+        return self.client.query_dataframe(sql, params=params)
+
+
+def processing_one_text(paper_id):
+    """Fetch one paper from ClickHouse and return its bibliographic fields.
+
+    :param paper_id: ``doc_id`` of the paper, an int or a string of digits
+    :return: dict with the keys title, author, degree, year, school, qikan_name
+    :raises ValueError: if ``paper_id`` cannot be converted to an int
+    :raises LookupError: if the table has no row for ``paper_id``
+    """
+    pureclient = PureClient()
+    print("paper_id", paper_id)
+    # The id arrives in a queue message, so it is forced to int and bound as a
+    # query parameter instead of being formatted into the SQL text.
+    sql = 'SELECT * FROM main_paper_message WHERE doc_id=%(doc_id)s'
+    result = pureclient.run(sql, {"doc_id": int(paper_id)})
+    print("result", result)
+    if len(result) == 0:
+        raise LookupError("no paper found for doc_id={}".format(paper_id))
+
+    # NOTE(review): the year is taken from the 6th "/"-separated component of the
+    # content path; this assumes a fixed directory layout - TODO confirm.
+    year = result['content'][0].split("/")[5]
+    # Authors are stored ";"-separated; citations use "," and no trailing separator.
+    author = str(result['author'][0]).strip(";").replace(";", ",")
+
+    paper_info = {
+        "title": result['title'][0],
+        "author": author,
+        "degree": result['degree'][0],
+        "year": year,
+        "school": result['school'][0],
+        "qikan_name": result['qikan_name'][0]
+    }
+    print("paper_info", paper_info)
+    return paper_info
+
+
+def ulit_recall_paper(recall_data_list_dict):
+    '''
+    Build one citation string for every recalled paper.
+
+    :param recall_data_list_dict: mapping whose keys are paper doc_ids (values unused)
+    :return data: list of citation strings without duplicates, in first-seen order
+    '''
+    data = []
+    for paper_id in recall_data_list_dict:
+        data_one = processing_one_text(paper_id)
+
+        # "[J]" when the degree field is "期刊" (journal); "[D]" for anything else.
+        degree = "[J]" if data_one['degree'] == "期刊" else "[D]"
+
+        # Prefer the school as the source; fall back to the journal name when blank.
+        school = data_one['school']
+        if school and str(school).strip():
+            source = school
+        else:
+            source = data_one['qikan_name']
+
+        paper_name = ".".join([data_one['author'], data_one['title'] + degree, ",".join([source, data_one['year']])])
+        data.append(paper_name + ".")
+    # dict.fromkeys removes duplicates but keeps order; set() shuffled the citations.
+    return list(dict.fromkeys(data))
+
+
+def _build_result(query_recall_dict):
+    """Turn one decoded recall message into the payload stored for the HTTP API."""
+    try:
+        # The upstream flag may be "0" or 0; both mean the recall call failed.
+        if str(query_recall_dict["is_success"]) == "0":
+            return {"resilt": "宇鹏接口不成功", "probabilities": None, "status_code": 400}
+        recall_data_list_dict = query_recall_dict["data"]
+        # After json.loads an empty result is an empty dict, not the string "{}";
+        # the string form is still accepted for backward compatibility.
+        if not recall_data_list_dict or recall_data_list_dict == "{}":
+            return {"resilt": "查询结果为空", "probabilities": None, "status_code": 400}
+        recall_data = "\n".join(ulit_recall_paper(recall_data_list_dict))
+        return {"resilt": recall_data, "probabilities": None, "status_code": 200}
+    except Exception as exc:
+        print("build result failed: ", repr(exc))
+        return {"resilt": "计算有问题", "probabilities": None, "status_code": 400}
+
+
+def classify_accurate_check():
+    """Consume recall messages from Redis forever and publish one result per request.
+
+    Each message popped from ``db_key_query_recall`` is turned into a payload that is
+    written to ``RESULT_LOG_DIR/{uuid}.json``; the file path is then stored in Redis
+    under the uuid with a TTL of ``RESULT_TTL_SECONDS``.
+    """
+    os.makedirs(RESULT_LOG_DIR, exist_ok=True)
+    while True:
+        raw = redis_.lpop(db_key_query_recall)
+        if raw is None:  # nothing queued: wait a little and poll again
+            time.sleep(1)
+            continue
+
+        print("计算结果")
+        try:
+            query_recall_dict = json.loads(raw.decode('UTF-8'))
+            query_recall_uuid = query_recall_dict["uuid"]
+        except Exception as exc:
+            # Without a uuid no requester can be answered; drop the message
+            # instead of letting the exception end the worker thread.
+            print("discarding malformed recall message: ", repr(exc))
+            continue
+
+        return_text = _build_result(query_recall_dict)
+
+        load_result_path = "{}/{}.json".format(RESULT_LOG_DIR, query_recall_uuid)
+        print("queue_uuid: ", query_recall_uuid)
+        print("load_result_path: ", load_result_path)
+
+        with open(load_result_path, 'w', encoding='utf8') as f2:
+            # ensure_ascii=False keeps Chinese text readable instead of \u escapes;
+            # indent=4 pretty-prints the file.
+            json.dump(return_text, f2, ensure_ascii=False, indent=4)
+
+        redis_.set(query_recall_uuid, load_result_path, ex=RESULT_TTL_SECONDS)
+
+
+if __name__ == '__main__':
+    t1 = Thread(target=classify_accurate_check)
+    t1.start()
\ No newline at end of file
diff --git a/flask_api.py b/flask_api.py
new file mode 100644
index 0000000..6c52f88
--- /dev/null
+++ b/flask_api.py
@@ -0,0 +1,232 @@
+import os
+import numpy as np
+from numpy.linalg import norm
+import json
+import datetime
+import pymysql
+import re
+import requests
+from flask import Flask, jsonify
+from flask import request
+import uuid
+import time
+import redis
+from threading import Thread
+from multiprocessing import Pool
+
+app = Flask(__name__)
+app.config["JSON_AS_ASCII"] = False
+
+# NOTE(review): the password stays hard-coded as a fallback so existing deployments
+# keep working; prefer supplying it through the environment.
+pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=8,
+                            password=os.environ.get("REDIS_PASSWORD", "zhicheng123*"))
+redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
+
+db_key_query = 'query_check_task'
+db_key_querying = 'querying_check_task'
+db_key_queryset = 'queryset_check_task'
+db_key_query_recall = 'query_recall'
+
+# Endpoint of the recall service.
+RECALL_URL = "http://192.168.31.74:50004/check1"
+# Directory that receives one JSON file per incoming request.
+REQUEST_LOG_DIR = './request_data_logs'
+# Directory where uploaded text files are stored.
+UPLOAD_DIR = "data/request"
+# Longest time in seconds a request waits for its result, and the pause between polls.
+RESULT_WAIT_SECONDS = float(os.environ.get("RESULT_WAIT_SECONDS", 600))
+RESULT_POLL_INTERVAL = 0.1
+
+
+def dialog_line_parse(url, text, timeout=100000):
+    """
+    POST ``text`` as JSON to ``url`` and return the decoded JSON reply.
+
+    :param url: URL of the model/service
+    :param text: JSON-serialisable payload
+    :param timeout: request timeout in seconds
+    :return: decoded reply on HTTP 200, otherwise an empty dict
+    """
+    response = requests.post(url, json=text, timeout=timeout)
+    if response.status_code == 200:
+        return response.json()
+
+    print("【{}】 Failed to get a proper response from remote "
+          "server. Status Code: {}. Response: {}"
+          "".format(url, response.status_code, response.text))
+    print(text)
+    return {}
+
+
+def recall_10(queue_uuid, title, abst_zh, content):
+    '''
+    Send one request to the recall service.
+
+    :param queue_uuid: request id, sent as "uuid"
+    :param title: paper title
+    :param abst_zh: abstract, sent as "abst_zh"
+    :param content: paper body text
+    :return: decoded reply of the service, or an empty dict if the call failed
+    '''
+    request_json = {
+        "uuid": queue_uuid,
+        "title": title,
+        "abst_zh": abst_zh,
+        "content": content
+    }
+    print(request_json)
+    try:
+        return dialog_line_parse(RECALL_URL, request_json)
+    except (requests.RequestException, ValueError) as exc:
+        # This runs in a background thread: report the failure instead of
+        # dying with an unhandled exception.
+        print("recall request failed: ", repr(exc))
+        return {}
+
+
+def uilt_content(content):
+    """Extract the abstract from the text of a paper.
+
+    The abstract is the text between the "摘要" heading and an end marker, tried in
+    this order: "关键词", "Abstract"/"abstract", "目录". If nothing is found, the first
+    15 "。"-separated sentences are returned instead.
+
+    :param content: full text of the paper
+    :return: abstract text
+    """
+    end_marker_groups = [["关键词"], ["Abstract", "abstract"], ["目录"]]
+    zhaiyao_text = ""
+
+    heading = re.search(r'(摘\s*要)', content)
+    if heading is not None:
+        zhaiyao_str = heading.group(1)
+        for markers in end_marker_groups:
+            end_marker = next((m for m in markers if m in content), None)
+            if end_marker is None:
+                continue
+            # Escape both ends so they are matched literally; a marker that only
+            # occurs before the heading (or on another line) yields no match and
+            # the next marker is tried instead of raising IndexError.
+            found = re.search("{}(.*?){}".format(re.escape(zhaiyao_str), re.escape(end_marker)), content)
+            if found is not None:
+                zhaiyao_text = found.group(1)
+                break
+
+    if zhaiyao_text == "":
+        content = str(content).replace("。\n", "。")
+        content_list = content.split("。")
+        zhaiyao_text = "".join(content_list[:15])
+    return zhaiyao_text
+
+
+def ulit_request_file(file):
+    """Save an uploaded ``.txt`` file and return its text as a single line.
+
+    :param file: upload object exposing ``filename`` and ``save(path)``
+    :return: the non-empty lines of the file joined by single spaces
+    :raises ValueError: if the upload is not a ``.txt`` file
+    """
+    # Keep only the last path component so a crafted name cannot leave UPLOAD_DIR.
+    file_name = os.path.basename(file.filename)
+    if file_name.split(".")[-1] != "txt":
+        raise ValueError("only .txt uploads are supported: {}".format(file_name))
+
+    os.makedirs(UPLOAD_DIR, exist_ok=True)
+    file_name_save = "{}/{}".format(UPLOAD_DIR, file_name)
+    file.save(file_name_save)
+    # Strict UTF-8 almost always rejects GBK bytes, whereas GBK can silently
+    # mis-decode UTF-8, so UTF-8 is tried first.
+    try:
+        with open(file_name_save, encoding="utf-8") as f:
+            content = f.read()
+    except UnicodeDecodeError:
+        with open(file_name_save, encoding="gbk") as f:
+            content = f.read()
+
+    return " ".join(line for line in content.split("\n") if line != "")
+
+
+def _wait_for_result(id_, timeout=RESULT_WAIT_SECONDS, interval=RESULT_POLL_INTERVAL):
+    """Poll Redis until the result path for ``id_`` appears.
+
+    :param id_: request uuid used as the Redis key
+    :param timeout: longest time to wait, in seconds
+    :param interval: pause between two polls, in seconds
+    :return: path of the result file, or None if nothing arrived in time
+    """
+    deadline = time.time() + timeout
+    while time.time() < deadline:
+        result = redis_.get(id_)  # path published by the worker for this query
+        if result is not None:
+            redis_.delete(id_)
+            return result.decode('UTF-8')
+        time.sleep(interval)  # do not hammer Redis or spin the CPU
+    return None
+
+
+@app.route("/", methods=["POST"])
+def handle_query():
+    """Handle ``POST /``: return up to ``nums`` numbered references for a paper.
+
+    Form fields: ``title``, ``abstract``, ``nums``. The reply is JSON with the keys
+    "resilt" (list of references, or "" on failure), "probabilities" and
+    "status_code" (200 or 400).
+    """
+    # NOTE(review): "resilt" looks like a typo of "result" but is part of the
+    # response format clients read, so it is kept.
+    return_text = {"resilt": "", "probabilities": None, "status_code": 400}
+    try:
+        title = request.form.get("title")
+        abstract = request.form.get('abstract')
+        nums = request.form.get('nums')
+        # Reject a missing or non-numeric "nums" before any work is started.
+        limit = int(nums)
+
+        # File upload is disabled for now, so the recall service gets an empty body.
+        content = ""
+
+        id_ = str(uuid.uuid1())  # unique id of this query
+        print("uuid: ", id_)
+        d = {
+            'id': id_,
+            'abstract': abstract,
+            'title': title,
+            'nums': nums
+        }
+
+        # Log the request first so a failure here does not leave a recall running.
+        os.makedirs(REQUEST_LOG_DIR, exist_ok=True)
+        load_request_path = '{}/{}.json'.format(REQUEST_LOG_DIR, id_)
+        with open(load_request_path, 'w', encoding='utf8') as f2:
+            # ensure_ascii=False keeps Chinese text readable; indent=4 pretty-prints.
+            json.dump(d, f2, ensure_ascii=False, indent=4)
+
+        Thread(target=recall_10, args=(id_, title, abstract, content)).start()
+
+        result_path = _wait_for_result(id_)
+        if result_path is None:
+            print("timed out waiting for result of ", id_)
+            return jsonify(return_text)
+
+        print("获取结果完成")
+        with open(result_path, encoding='utf8') as f1:
+            result_dict = json.load(f1)
+        reference = result_dict["resilt"]
+        status_code = str(result_dict["status_code"])
+
+        print("结果分析完成")
+        print("reference", reference)
+        if status_code == "200":
+            reference_list = reference.split("\n")[:limit]
+            print(reference_list)
+            reference = ["[{}]{}".format(i, ref) for i, ref in enumerate(reference_list, 1)]
+            return_text = {"resilt": reference, "probabilities": None, "status_code": 200}
+    except Exception as exc:
+        print("handle_query failed: ", repr(exc))
+        return_text = {"resilt": "", "probabilities": None, "status_code": 400}
+    return jsonify(return_text)  # send the result back
+
+
+if __name__ == "__main__":
+    app.run(host="0.0.0.0", port=17000, threaded=True)
\ No newline at end of file
diff --git a/gunicorn_config.py b/gunicorn_config.py
new file mode 100644
index 0000000..e50ebe5
--- /dev/null
+++ b/gunicorn_config.py
@@ -0,0 +1,21 @@
+# Number of worker processes
+workers = 2
+# Listen address and port (change as needed)
+bind = '0.0.0.0:17000'
+
+loglevel = 'debug'
+
+worker_class = "gevent"
+# Run as a daemon (keeps running after the launching terminal is closed)
+daemon = True
+# Worker timeout of 120s (gunicorn's default is 30s); adjust as needed
+timeout = 120
+# Paths of the access log and the error log
+accesslog = './logs/acess.log'
+errorlog = './logs/error.log'
+# access_log_format = '%(h) - %(t)s - %(u)s - %(s)s %(H)s'
+# errorlog = '-'  # log to stderr
+
+
+# Maximum number of simultaneous connections per worker
+worker_connections = 20000
diff --git a/run_api_gunicorn.sh b/run_api_gunicorn.sh
new file mode 100644
index 0000000..3536f48
--- /dev/null
+++ b/run_api_gunicorn.sh
@@ -0,0 +1 @@
+gunicorn flask_api:app -c gunicorn_config.py
\ No newline at end of file