Browse Source

生成参考文献第一次提交

master
majiahui@haimaqingfan.com 1 year ago
commit
62c34defd3
  1. 0
      .idea/.gitignore
  2. 21
      README.md
  3. 207
      accurate_check.py
  4. 220
      flask_api.py
  5. 21
      gunicorn_config.py
  6. 1
      run_api_gunicorn.sh

0
.idea/.gitignore

21
README.md

@ -0,0 +1,21 @@
## 安装环境
```bash
conda create -n llama_paper python=3.8
```
## 启动项目
启动此项目前必须启动 vllm-main 项目
```bash
conda activate llama_paper
bash run_api_gunicorn.sh
```
## 测试
```bash
curl -H "Content-Type: application/json" -X POST -d '{"orderid": "EEAE880E-BE95-11EE-8D23-D5E5C66DD02E"}' http://101.37.83.210:16005/search
```
返回的 "status_code" 不为 400 即表示调用成功

207
accurate_check.py

@ -0,0 +1,207 @@
import json
import datetime
import pymysql
import re
import requests
from flask import Flask, jsonify
from flask import request
import uuid
import time
import redis
from threading import Thread
# Shared redis connection; db 8 holds the reference-generation task queues.
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=8, password="zhicheng123*")
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
# Queue / state key names shared with flask_api.py.
db_key_query = 'query_check_task'
db_key_querying = 'querying_check_task'
db_key_queryset = 'queryset_check_task'
db_key_query_recall = 'query_recall'  # recall results consumed by the worker loop below
def run_query(conn, sql, params):
    """Execute *sql* with *params* on *conn* and return all fetched rows."""
    with conn.cursor() as cur:
        cur.execute(sql, params)
        return cur.fetchall()
# def processing_one_text(paper_id):
# conn = pymysql.connect(
# host='192.168.31.145',
# port=3306,
# user='root',
# password='123456',
# db='zhiwang_db',
# charset='utf8mb4',
# cursorclass=pymysql.cursors.DictCursor
# )
#
# sql = 'SELECT * FROM main_table_paper_detail_message WHERE doc_id=%s'
# params = (paper_id,)
#
# result = run_query(conn, sql, params)
#
# conn.close()
# print(result[0]['title'], result[0]['author'])
# title = result[0]['title']
# author = result[0]['author']
# degree = result[0]['degree']
# year = result[0]['content'].split("/")[5]
# content_path = result[0]['content']
# school = result[0]['school']
# qikan_name = result[0]['qikan_name']
# author = str(author).strip(";")
# author = str(author).replace(";", ",")
# # select
# # school, qikan_name
# # from main_table_paper_detail_message limit
# # 10000 \G;;
#
# try:
# with open(content_path, encoding="utf-8") as f:
# text = f.read()
# except:
# with open(content_path, encoding="gbk") as f:
# text = f.read()
#
# paper_info = {
# "title": title,
# "author": author,
# "degree": degree,
# "year": year,
# "paper_len_word": len(text),
# "school": school,
# "qikan_name": qikan_name
# }
# return paper_info
from clickhouse_driver import Client
class PureClient:
    """Thin wrapper around a ClickHouse connection for dataframe queries."""

    def __init__(self, database='test_db'):
        # LAN ClickHouse node; credentials are fixed for this deployment.
        self.client = Client(host='192.168.31.74', port=9000, user='default',
                             password='zhicheng123*', database=database)

    def run(self, sql):
        """Run *sql* and return the result as a pandas DataFrame."""
        return self.client.query_dataframe(sql)
def processing_one_text(paper_id):
    """Fetch one paper's metadata from ClickHouse and normalise it.

    :param paper_id: doc_id of the paper; must be a plain identifier
        (word characters only) because it is interpolated into SQL.
    :return: dict with title / author / degree / year / school / qikan_name.
    :raises ValueError: if paper_id looks unsafe or no row matches.
    """
    pureclient = PureClient()
    print("paper_id", paper_id)
    # BUG FIX: doc_id is interpolated straight into the SQL text, so reject
    # anything that is not a plain identifier to prevent SQL injection.
    if not re.fullmatch(r"\w+", str(paper_id)):
        raise ValueError("invalid paper_id: {!r}".format(paper_id))
    sql = 'SELECT * FROM main_paper_message WHERE doc_id={}'.format(paper_id)
    result = pureclient.run(sql)
    print("result", result)
    # Guard the empty-result case explicitly instead of an opaque IndexError.
    if len(result) == 0:
        raise ValueError("no paper found for doc_id {}".format(paper_id))
    title = result['title'][0]
    author = result['author'][0]
    degree = result['degree'][0]
    # content is a filesystem path; the 6th path segment carries the year.
    # NOTE(review): assumes the path layout is stable — confirm with the
    # table's ingestion code.
    year = result['content'][0].split("/")[5]
    school = result['school'][0]
    qikan_name = result['qikan_name'][0]
    # "a;b;c;" -> "a,b,c"
    author = str(author).strip(";").replace(";", ",")
    paper_info = {
        "title": title,
        "author": author,
        "degree": degree,
        "year": year,
        "school": school,
        "qikan_name": qikan_name,
    }
    print("paper_info", paper_info)
    return paper_info
def ulit_recall_paper(recall_data_list_dict):
    """Build GB/T 7714-style reference strings for the recalled papers.

    :param recall_data_list_dict: mapping of doc_id -> recall payload; only
        the keys (doc ids) are used here.
    :return: de-duplicated list of formatted reference strings, e.g.
        "author.title[J].source,year."
    """
    data = []
    for doc_id in recall_data_list_dict:
        info = processing_one_text(doc_id)
        print("ulit_recall_paper-1")
        # 期刊 (journal) articles are tagged [J]; everything else [D] (thesis).
        degree = "[J]" if info['degree'] == "期刊" else "[D]"
        # Prefer the school as the source; fall back to the journal name
        # when school is the placeholder single space.
        source = info['school'] if info['school'] != " " else info['qikan_name']
        print("ulit_recall_paper-2")
        paper_name = ".".join([info['author'], info['title'] + degree,
                               ",".join([source, str(info['year'])])]) + "."
        data.append(paper_name)
        print("ulit_recall_paper-3")
    # BUG FIX: de-duplicate while keeping first-seen order — list(set(...))
    # produced a nondeterministic ordering between runs.
    return list(dict.fromkeys(data))
def classify_accurate_check():
    """Daemon loop: pop recall results from redis, format references, persist.

    Pops JSON payloads from ``db_key_query_recall``; writes the formatted
    result to ``./new_data_logs/<uuid>.json`` and stores that path in redis
    under the task uuid (24h TTL) so the flask side can pick it up.
    """
    import os  # local import: only needed to ensure the log dir exists
    while True:
        if redis_.llen(db_key_query_recall) == 0:  # nothing queued, poll again
            time.sleep(1)
            continue
        print("计算结果")
        query_recall = redis_.lpop(db_key_query_recall).decode('UTF-8')
        query_recall_dict = json.loads(query_recall)
        query_recall_uuid = query_recall_dict["uuid"]
        recall_data_list_dict = query_recall_dict["data"]
        is_success = query_recall_dict["is_success"]
        try:
            if is_success == "0":
                # Upstream recall service reported failure.
                return_text = {"resilt": "宇鹏接口不成功", "probabilities": None, "status_code": 400}
            else:
                if recall_data_list_dict == "{}":
                    return_text = {"resilt": "查询结果为空", "probabilities": None, "status_code": 400}
                else:
                    recall_data_list = ulit_recall_paper(recall_data_list_dict)
                    recall_data = "\n".join(recall_data_list)
                    return_text = {"resilt": recall_data, "probabilities": None, "status_code": 200}
        except Exception:  # keep the worker alive; report a generic failure
            return_text = {"resilt": "计算有问题", "probabilities": None, "status_code": 400}
        load_result_path = "./new_data_logs/{}.json".format(query_recall_uuid)
        print("queue_uuid: ", query_recall_uuid)
        print("load_result_path: ", load_result_path)
        # BUG FIX: create the output directory if missing instead of crashing.
        os.makedirs(os.path.dirname(load_result_path), exist_ok=True)
        with open(load_result_path, 'w', encoding='utf8') as f2:
            # ensure_ascii=False keeps Chinese readable; indent=4 for humans.
            json.dump(return_text, f2, ensure_ascii=False, indent=4)
        redis_.set(query_recall_uuid, load_result_path, 86400)
if __name__ == '__main__':
    # Run the redis-consuming worker loop in a background thread.
    t1 = Thread(target=classify_accurate_check)
    t1.start()

220
flask_api.py

@ -0,0 +1,220 @@
import os
import numpy as np
from numpy.linalg import norm
import json
import datetime
import pymysql
import re
import requests
from flask import Flask, jsonify
from flask import request
import uuid
import time
import redis
from threading import Thread
from multiprocessing import Pool
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False  # emit Chinese characters verbatim in JSON responses
# Shared redis connection; db 8 holds the reference-generation task queues.
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=8, password="zhicheng123*")
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
# Queue / state key names shared with accurate_check.py.
db_key_query = 'query_check_task'
db_key_querying = 'querying_check_task'
db_key_queryset = 'queryset_check_task'
db_key_query_recall = 'query_recall'  # recall results consumed by accurate_check.py
def dialog_line_parse(url, text):
    """POST *text* as JSON to *url* and return the parsed JSON response.

    :param url: model service endpoint.
    :param text: JSON-serialisable request payload.
    :return: response JSON dict, or ``{}`` on a non-200 status.
    """
    response = requests.post(
        url,
        json=text,
        timeout=100000
    )
    if response.status_code == 200:
        return response.json()
    # logger.error(
    #     "【{}】 Failed to get a proper response from remote "
    #     "server. Status Code: {}. Response: {}"
    #     "".format(url, response.status_code, response.text)
    # )
    # BUG FIX: the opening 【 bracket was missing from the log message.
    print("【{}】 Failed to get a proper response from remote "
          "server. Status Code: {}. Response: {}"
          "".format(url, response.status_code, response.text))
    print(text)
    return {}
def recall_10(queue_uuid, title, abst_zh, content):
    """Fire-and-forget call to the recall service (candidate-paper retrieval).

    The service pushes its result onto the redis recall queue keyed by
    *queue_uuid*; the HTTP return value is not needed here (the unused
    local assignment was removed).
    """
    request_json = {
        "uuid": queue_uuid,
        "title": title,
        "abst_zh": abst_zh,
        "content": content
    }
    print(request_json)
    dialog_line_parse("http://192.168.31.74:50004/check1", request_json)
def uilt_content(content):
    """Extract the abstract (摘要) section from a paper's full text.

    Returns the text between the "摘要" marker and the first of
    关键词 / Abstract / 目录. If no such span is found, falls back to
    the first 15 sentences of the text (joined without separators).
    """
    zhaiyao_en_list = ["Abstract", "abstract"]
    mulu_list = ["目录"]
    key_word_list = ["关键词"]
    key_word_bool = False
    key_word_str = ""
    zhaiyao_bool = False
    zhaiyao_en_bool = False
    zhaiyao_str = ""
    zhaiyao_en_str = ""
    mulu_str = ""
    zhaiyao_text = ""
    mulu_bool = False
    # 摘要 may be spelled with whitespace between the two characters.
    pantten_zhaiyao = r'(摘\s*要)'
    result_biaoti_list = re.findall(pantten_zhaiyao, content)
    if len(result_biaoti_list) != 0:
        zhaiyao_str = result_biaoti_list[0]
        zhaiyao_bool = True
    for i in zhaiyao_en_list:
        if i in content:
            zhaiyao_en_bool = True
            zhaiyao_en_str = i
            break
    for i in mulu_list:
        if i in content:
            mulu_str = i
            mulu_bool = True
            break
    for i in key_word_list:
        if i in content:
            key_word_str = i
            key_word_bool = True
            break
    # Pick the end marker in priority order: 关键词 > Abstract > 目录.
    if zhaiyao_bool and key_word_bool:
        end_str = key_word_str
    elif zhaiyao_bool and zhaiyao_en_bool:
        end_str = zhaiyao_en_str
    elif zhaiyao_bool and mulu_bool:
        end_str = mulu_str
    else:
        end_str = ""
    if end_str:
        pantten_zhaiyao = "{}(.*?){}".format(zhaiyao_str, end_str)
        result_biaoti_list = re.findall(pantten_zhaiyao, content)
        zhaiyao_text = result_biaoti_list[0]
    if zhaiyao_text == "":
        # Fallback: first 15 sentences. BUG FIX: the original split on the
        # empty string (""), which raises ValueError in Python; the intended
        # separator is the Chinese full stop "。".
        content = str(content).replace("\n", "")
        content_list = content.split("。")
        zhaiyao_text = "".join(content_list[:15])
    return zhaiyao_text
def ulit_request_file(file):
    """Save an uploaded .txt file and return its text, newlines collapsed.

    :param file: werkzeug FileStorage-like object (has ``.filename`` and
        ``.save``).
    :return: file content with blank lines dropped and newlines replaced by
        spaces; "" for non-.txt uploads (BUG FIX: previously the return
        raised UnboundLocalError because ``content`` was never assigned).
    """
    import os  # local import: only used to create the upload directory
    content = ""
    file_name = file.filename
    if file_name.split(".")[-1] == "txt":
        file_name_save = "data/request/{}".format(file_name)
        # BUG FIX: make sure the upload directory exists before saving.
        os.makedirs(os.path.dirname(file_name_save), exist_ok=True)
        file.save(file_name_save)
        # Uploads arrive in either GBK or UTF-8; try GBK first, matching the
        # convention used elsewhere in this codebase.
        try:
            with open(file_name_save, encoding="gbk") as f:
                content = f.read()
        except Exception:
            with open(file_name_save, encoding="utf-8") as f:
                content = f.read()
        content = " ".join([i for i in content.split("\n") if i != ""])
    return content
@app.route("/", methods=["POST"])
def handle_query():
    """HTTP entry point: queue a reference-generation task and wait for it.

    Form fields: ``title``, ``abstract``, ``nums`` (max number of references).
    Returns JSON: ``{"resilt": <list|str>, "probabilities": None,
    "status_code": 200|400}``.
    """
    try:
        title = request.form.get("title")
        abstract = request.form.get('abstract')
        nums = request.form.get('nums')
        content = ""
        id_ = str(uuid.uuid1())  # unique task id shared with the worker side
        print("uuid: ", id_)
        print(id_)
        d = {
            'id': id_,
            'abstract': abstract,
            'title': title,
            'nums': nums
        }
        # Kick off the recall service asynchronously; it reports back via redis.
        Thread_rellce = Thread(target=recall_10, args=(id_, title, abstract, content,))
        Thread_rellce.start()
        # Persist the request for auditing/debugging.
        load_request_path = './request_data_logs/{}.json'.format(id_)
        with open(load_request_path, 'w', encoding='utf8') as f2:
            json.dump(d, f2, ensure_ascii=False, indent=4)
        # Poll redis for the worker's result path. BUG FIX: sleep between
        # polls instead of hammering redis in a tight busy-wait loop.
        while True:
            result = redis_.get(id_)
            if result is not None:
                redis_.delete(id_)
                result_path = result.decode('UTF-8')
                break
            time.sleep(0.2)
        print("获取结果完成")
        with open(result_path, encoding='utf8') as f1:
            result_dict = json.load(f1)
        reference = result_dict["resilt"]
        status_code = str(result_dict["status_code"])
        print("结果分析完成")
        print("reference", reference)
        if status_code == "400":
            return_text = {"resilt": "", "probabilities": None, "status_code": 400}
        else:
            # Worker returns one reference per line; trim to the requested count
            # and number them "[1]...", "[2]...".
            reference_list = reference.split("\n")
            reference_list = reference_list[:int(nums)]
            print(reference_list)
            reference = [f"[{str(i+1)}]" + reference_list[i] for i in range(len(reference_list))]
            if status_code == "200":
                return_text = {"resilt": reference, "probabilities": None, "status_code": 200}
            else:
                return_text = {"resilt": "", "probabilities": None, "status_code": 400}
    except Exception:  # any failure maps to a generic 400 response
        return_text = {"resilt": "", "probabilities": None, "status_code": 400}
    return jsonify(return_text)
if __name__ == "__main__":
    # Dev server entry point; production runs via gunicorn (see gunicorn_config.py).
    app.run(host="0.0.0.0", port=17000, threaded=True)

21
gunicorn_config.py

@ -0,0 +1,21 @@
# Number of parallel worker processes
workers = 2
# Bind address/port (change as needed)
bind = '0.0.0.0:17000'
loglevel = 'debug'
worker_class = "gevent"
# Daemonize: keep the server running after the controlling shell exits
daemon = True
# Worker timeout in seconds (gunicorn default is 30s)
timeout = 120
# Access and error log paths
# NOTE(review): 'acess.log' looks like a typo for 'access.log' — confirm
# before renaming, since log tooling may already point at this path.
accesslog = './logs/acess.log'
errorlog = './logs/error.log'
# access_log_format = '%(h) - %(t)s - %(u)s - %(s)s %(H)s'
# errorlog = '-'  # log to stdout instead
# Maximum concurrent connections per worker
worker_connections = 20000

1
run_api_gunicorn.sh

@ -0,0 +1 @@
# Launch the flask_api app under gunicorn using gunicorn_config.py
gunicorn flask_api:app -c gunicorn_config.py
Loading…
Cancel
Save