13 changed files with 1013 additions and 0 deletions
File diff suppressed because one or more lines are too long
@@ -0,0 +1,254 @@
# -*- coding: utf-8 -*-
# @Time: 18:01
# @Author:ZYP
# @File:SearchSimPaper.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import gc
import time

# =========================================================================================
# Find similar documents
# · Take the intersection of keywords between documents
# · Then compute keyword-level similarity between the selected documents and the submitted document
# · Finally pick out the most similar documents, sort them and return
# =========================================================================================

import math
import numpy as np
from collections import defaultdict
from pymysql.converters import escape_string
from sklearn.metrics.pairwise import cosine_similarity
from util import cut_text, l2_normal, get_word_vec


def load_inverted_table(class_list, mysql, log):
    """Merge the inverted index of every predicted class and return the merged dict plus the total number of papers in those classes."""
    # merged inverted index {word: "doc_id1,doc_id2,doc_id3,..."}
    total_inverted_dict1 = {}
    # total number of papers across the selected classes
    total_nums1 = 0

    for label_num in class_list:
        conn, cursor = mysql.open()
        select_sql = """
            select word, paper_doc_id from word_map_paper_{};
        """.format(str(label_num))

        s_time1 = time.time()
        cursor.execute(select_sql)
        for word, paper_doc_id in cursor.fetchall():
            if word not in total_inverted_dict1.keys():
                total_inverted_dict1[word] = paper_doc_id
            else:
                # total_inverted_dict1[word] = ','.join(
                #     set(total_inverted_dict1[word].split(',') + paper_doc_id.split(',')))
                total_inverted_dict1[word] = total_inverted_dict1[word] + ',' + paper_doc_id

        e_time1 = time.time()
        log.log('Time spent loading the inverted index of class {}: {}s'.format(str(label_num), e_time1 - s_time1))

        s_time2 = time.time()
        select_paper_num_sql = """
            select count_number from count_map where label_num={};
        """.format(str(label_num))

        cursor.execute(select_paper_num_sql)

        for nums in cursor.fetchall():
            total_nums1 += int(nums[0])

        e_time2 = time.time()
        log.log('Time spent counting papers of class {}: {}s'.format(str(label_num), e_time2 - s_time2))

        mysql.close(cursor, conn)

    return total_inverted_dict1, total_nums1
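
# Illustrative sketch (added for documentation, not part of the original file): the merged
# inverted index maps each keyword to a comma-separated doc-id string. The rows below are
# made up; only the string format mirrors the word_map_paper_<label> tables read above.
def _demo_merge_inverted_rows():
    rows = [('neural network', 'doc_001,doc_007'), ('neural network', 'doc_012')]
    merged = {}
    for word, doc_ids in rows:
        merged[word] = doc_ids if word not in merged else merged[word] + ',' + doc_ids
    return merged  # {'neural network': 'doc_001,doc_007,doc_012'}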


def select_sim_doc_message(sim_doc1, mysql):
    """
    Look up the stored information for each similar doc_id, compute its weighted mean
    document vector and return the results as a dict.
    :param sim_doc1: list of similar document ids, [doc_id1, doc_id2, ...]
    :return: {doc_id: (doc_avg_vec, doc_path)}
    """
    all_paper_vec_dict = {}
    conn, cursor = mysql.open()
    for doc_id in sim_doc1:
        select_sql = """
            select tb1.doc_id, tb1.title, tb1.abst_zh, tb2.vsm, tb1.content from
            (
                (select doc_id, title, abst_zh, content from main_table_paper_detail_message) tb1
                left join
                (select doc_id, vsm from id_keywords_weights) tb2
                on
                tb1.doc_id=tb2.doc_id
            ) where tb1.doc_id="{}";
        """.format(escape_string(doc_id))

        cursor.execute(select_sql)

        sim_doc_id, sim_title, sim_abst, sim_vsm, sim_content_path = cursor.fetchone()
        sim_vsm_dict = {weight.split('@#$@')[0]: float(weight.split('@#$@')[1]) for weight in sim_vsm.split('&*^%')}
        vector_paper = []
        value_sum = 0.0
        for word, weight in sim_vsm_dict.items():
            if word in sim_title:
                value = 0.5 * weight
            elif word in sim_abst:
                value = 0.3 * weight
            else:
                value = 0.2 * weight

            word_vec = get_word_vec(word)
            if isinstance(word_vec, int):
                continue
            vector_paper.append(word_vec * value)
            value_sum += value

        del sim_vsm_dict
        gc.collect()

        # weighted mean vector of the document's keywords
        # avg_vector = np.array(np.sum(np.array(vector_paper, dtype=np.float32), axis=0) / len(vector_paper))
        avg_vector = np.array(np.sum(np.array(vector_paper, dtype=np.float32), axis=0) / value_sum)
        all_paper_vec_dict[doc_id] = (avg_vector, sim_content_path)

    mysql.close(cursor, conn)

    return all_paper_vec_dict
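
# Illustrative sketch (not part of the original file): the vsm column stores keyword weights
# as 'word@#$@weight' pairs joined by '&*^%'. The sample string is made up; only the
# delimiters mirror id_keywords_weights.vsm as parsed above.
def _demo_parse_vsm():
    sim_vsm = 'deep learning@#$@0.62&*^%image segmentation@#$@0.38'
    return {p.split('@#$@')[0]: float(p.split('@#$@')[1]) for p in sim_vsm.split('&*^%')}
    # {'deep learning': 0.62, 'image segmentation': 0.38}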


def submit_paper_avg_vec(paper_dict1, tf_weight_dict):
    """Compute the weighted mean vector of the submitted paper from its tf weights and return it as a numpy array."""
    vector_paper = []
    value_sum = 0.0
    for word, weight in tf_weight_dict.items():
        if word in paper_dict1['title']:
            value = 0.5 * weight
        elif word in paper_dict1['abst_zh']:
            value = 0.3 * weight
        else:
            value = 0.2 * weight

        word_vec = get_word_vec(word)
        if isinstance(word_vec, int):
            continue

        vector_paper.append(word_vec * value)
        value_sum += value

    # avg_vector = np.array(np.sum(np.array(vector_paper, dtype=np.float32), axis=0) / len(vector_paper))
    avg_vector = np.array(np.sum(np.array(vector_paper, dtype=np.float32), axis=0) / value_sum)

    return avg_vector


def compare_sim_in_papers(check_vector, sim_message, top=40):
    """
    Compute the cosine similarity between the submitted document and every candidate document.
    :param check_vector: text vector of the submitted paper
    :param sim_message: candidate similar documents, stored as a dict {doc_id: (vector, content_path)}
    :param top: number of most similar documents to return
    :return: dict of similar documents, {doc_id: (similarity score, document path)}
    """
    sim_res_dict = {}
    for doc_id, (vector, content_path) in sim_message.items():
        # sim_res_dict[doc_id] = cosine_similarity([scalar(check_vector), scalar(vector)])[0][1]
        sim_res_dict[doc_id] = (str(cosine_similarity([check_vector, vector])[0][1]), content_path)
    _ = sorted(sim_res_dict.items(), key=lambda x: float(x[1][0]), reverse=True)
    return {key: value for key, value in _[:top]}
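
# Illustrative sketch (not part of the original file): ranking two made-up document vectors
# against a query vector with compare_sim_in_papers; the ids and paths are hypothetical.
def _demo_compare_sim_in_papers():
    check_vector = np.array([0.1, 0.9, 0.0], dtype=np.float32)
    sim_message = {
        'doc_001': (np.array([0.1, 0.8, 0.1], dtype=np.float32), '/data/doc_001.txt'),
        'doc_002': (np.array([0.9, 0.1, 0.0], dtype=np.float32), '/data/doc_002.txt'),
    }
    return compare_sim_in_papers(check_vector, sim_message, top=1)
    # roughly {'doc_001': ('0.99...', '/data/doc_001.txt')}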


def search_sim_paper(paper_dict, class_list, mysql, log, top=100):
    """
    Search the database for documents similar to the submitted paper and return the most
    similar ones, which are then used for sentence-level checking.
    :param paper_dict: the pre-processed, structured submitted paper
    :param class_list: list of class ids predicted by the model for the submitted paper
    :param top: number of documents to return
    :return: dict of similar documents, {doc_id: (similarity score, document path)}
    """
    all_str = paper_dict['title'] + '。' + paper_dict['abst_zh'] + '。' + paper_dict['content']

    # merge the inverted indexes and count the total number of papers; total_inverted_dict is the merged index
    s0 = time.time()
    total_inverted_dict, total_nums = load_inverted_table(class_list, mysql, log)
    e0 = time.time()
    log.log('Time spent querying the inverted index: {}s'.format(e0 - s0))

    s1 = time.time()
    # word-frequency dict of the submitted document {word1: freq1, word2: freq2, ...}
    word_fre_dict = cut_text(all_str, tokenizer='jieba')
    e1 = time.time()
    log.log('Time spent on tokenization: {}s'.format(e1 - s1))

    s2 = time.time()
    # compute the tf-idf value of every word in the submitted document
    tf_idf_dict = {}
    for word, freq in word_fre_dict.items():
        if freq <= 2:
            continue
        tf = freq / sum(word_fre_dict.values())
        if word in total_inverted_dict.keys():
            idf = math.log(total_nums / (len(set(total_inverted_dict[word].split(','))) + 1))
        else:
            idf = math.log(total_nums / 1)

        tf_idf = tf * idf
        tf_idf_dict[word] = tf_idf
    e2 = time.time()
    log.log('Time spent computing tf-idf values for the submitted document: {}s'.format(e2 - s2))

    s3 = time.time()
    # top-15 words and their weights
    tf_dict = l2_normal(tf_idf_dict)
    e3 = time.time()
    log.log('Time spent on weight normalization: {}s'.format(e3 - s3))

    s4 = time.time()
    # count, per candidate document, how many of the top keywords it shares with the submitted one
    count_words_num = defaultdict(int)
    for word, weight in tf_dict.items():
        if word in total_inverted_dict.keys():
            for doc_id in set(total_inverted_dict[word].split(',')):
                count_words_num[doc_id] += 1
        else:
            continue
    e4 = time.time()
    log.log('Time spent counting doc_id intersections: {}s'.format(e4 - s4))

    # sort by overlap count
    count_word_num = {i: j for i, j in sorted(count_words_num.items(), key=lambda x: x[1], reverse=True)}

    # take the 200 most overlapping documents as candidates
    sim_doc = list(count_word_num.keys())[:200]

    # compute the mean document vector of these 200 candidates
    s_time1 = time.time()
    sim_paper_vec_dict = select_sim_doc_message(sim_doc, mysql)
    e_time1 = time.time()
    log.log('Time spent computing the mean vectors of the 200 candidates: {}s'.format(e_time1 - s_time1))

    # compute the mean document vector of the submitted document
    s_time2 = time.time()
    submit_vec = submit_paper_avg_vec(paper_dict, tf_dict)
    e_time2 = time.time()
    log.log('Time spent computing the mean vector of the submitted document: {}s'.format(e_time2 - s_time2))

    # compute and sort the similarities between the submitted document and the candidates,
    # keeping the top documents for full-text checking
    s_time3 = time.time()
    sim_paper_dict = compare_sim_in_papers(submit_vec, sim_paper_vec_dict, top=top)
    e_time3 = time.time()
    log.log('Time spent computing (and sorting) the similarities between the submitted document and the candidates: {}s'.format(e_time3 - s_time3))

    del total_inverted_dict
    del total_nums
    del submit_vec
    del sim_paper_vec_dict
    del count_word_num
    del sim_doc
    del word_fre_dict

    gc.collect()

    return sim_paper_dict
@@ -0,0 +1,85 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/29 18:58
# @Author:ZYP
# @File:demo01_multiprocessimg.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import json
import multiprocessing
import os
import time
import uuid
import signal
from util import Logging

import redis
import requests

from SearchSimPaper import search_sim_paper
from mysql_collect import mysql

pool = redis.ConnectionPool(host='192.168.31.145', port=63179, max_connections=50, password='zhicheng123*', db=8)
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)

db_key_query = 'query'

log = Logging()


def check_main_func(uuid_):
    while True:
        if redis_.llen(db_key_query) == 0:
            continue
        while True:
            ss = redis_.rpop(db_key_query)
            if ss is None:
                time.sleep(2)
            else:
                break

        paper_data = json.loads(ss.decode())
        id_ = paper_data["id"]
        message_dict = paper_data["paper"]
        class_res = \
            json.loads(requests.post('http://192.168.31.145:50003/roformer', data=json.dumps(message_dict)).text)[
                'label_num']
        # class_res = [117, 36, 81]

        sim_paper_id_dict = search_sim_paper(message_dict, class_res, mysql, log)
        redis_.set(id_, json.dumps(sim_paper_id_dict))

        pid = redis_.hget('process_pid', uuid_)
        log.log("uuid of this process:", uuid_)
        redis_.hdel("process_pid", uuid_)
        os.kill(int(pid), signal.SIGTERM)
        break


# def set_process():
#     name = str(uuid.uuid1())
#     process = multiprocessing.Process(target=check_main_func, args=(), name=name)
#     process.start()
#     process.join()
#     return name, process


if __name__ == '__main__':
    # while redis_.llen('process_num') < 4:
    #     name, process = set_process()
    #     redis_.lpush('process_num', name)
    #     process.is_alive()

    # pool = multiprocessing.Pool(processes=4)
    # while True:
    #     if redis_.llen('process_num') < 4:
    #         redis_.lpush('process_num', '1')
    #         pool.apply_async(check_main_func, args=())

    # pool = multiprocessing.Pool(processes=4)
    while True:
        if redis_.hlen('process_pid') < 4:
            uuid_ = str(uuid.uuid1())
            process = multiprocessing.Process(target=check_main_func, args=(uuid_,))
            process.start()
            process_id = process.pid
            redis_.hset('process_pid', uuid_, str(process_id))
@@ -0,0 +1,110 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/28 15:03
# @Author:ZYP
# @File:demo01_test_redis.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import time

import flask
import redis
import uuid
import json
import requests
# from util import Logging
# from threading import Thread
# from mysql_collect import mysql
# from SearchSimPaper import search_sim_paper
# import jieba
from flask import request

# import multiprocessing

app_check = flask.Flask(__name__)
pool = redis.ConnectionPool(host='192.168.31.145', port=63179, max_connections=50, password='zhicheng123*', db=8)
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)

# pool1 = redis.ConnectionPool(host='192.168.31.145', port=63179, max_connections=50, password='zhicheng123*', db=11)
# redis_1 = redis.Redis(connection_pool=pool1, decode_responses=True)

# jieba.initialize()

db_key_query = 'query'


# db_key_result = 'result'

# log = Logging()


# def check_main_func():
#     while True:
#         if redis_.llen(db_key_query) == 0:
#             continue
#         while True:
#             ss = redis_.rpop(db_key_query)
#             if ss is None:
#                 time.sleep(2)
#             else:
#                 break
#
#         paper_data = json.loads(ss.decode())
#         id_ = paper_data["id"]
#         message_dict = paper_data["paper"]
#         class_res = \
#             json.loads(requests.post('http://192.168.31.145:50003/roformer', data=json.dumps(message_dict)).text)[
#                 'label_num']
#
#         sim_paper_id_dict = search_sim_paper(message_dict, class_res, mysql, log)
#         redis_.set(id_, json.dumps(sim_paper_id_dict))


@app_check.route("/check", methods=["POST"])
def handle_query():
    s = time.time()
    message_dict = json.loads(request.data.decode())
    uuid_request = str(message_dict['uuid'])
    id_ = str(uuid.uuid1())  # generate a unique id for this query
    d = {'id': id_, 'paper': message_dict}  # bind the paper text to the query id
    redis_.rpush(db_key_query, json.dumps(d))
    while True:
        result = redis_.get(id_)
        if result is not None:
            redis_.delete(id_)
            result_text = {'uuid': uuid_request, 'data': result.decode('UTF-8')}
            # result_text = json.loads(result.decode('UTF-8'))
            break
    e = time.time()
    print('{} took {} s'.format(uuid_request, (e - s)))

    redis_.lpush('query_recall', json.dumps(result_text))

    return uuid_request  # return the result
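
# Illustrative client sketch (not part of the original file): posting a paper to /check.
# The field names follow what search_sim_paper and the roformer service read (title,
# abst_zh, content) plus the uuid this route expects; the host and sample text are assumptions.
def _demo_post_check():
    paper = {'uuid': 'demo-1', 'title': '...', 'abst_zh': '...', 'content': '...'}
    resp = requests.post('http://127.0.0.1:50004/check', data=json.dumps(paper))
    return resp.text  # echoes the submitted uuid; the ranked result is pushed onto 'query_recall'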


# # return '1'


if __name__ == "__main__":
    # for i in range(2):
    #     t = Thread(target=check_main_func, args=())
    #     t.start()
    # processes = []
    #
    # # create and start multiple worker processes
    # for i in range(2):
    #     process = multiprocessing.Process(target=check_main_func, args=())
    #     processes.append(process)
    #     process.start()

    app_check.run(debug=False, host='0.0.0.0', port=50004)

    # res = redis_.rpop(db_key_query)
    # print(res)

    # id_ = "51bc72dc-464e-11ee-baf3-45147420c4fb"
    # res = redis_.get(id_)
    # if res is not None:
    #     redis_.delete(id_)
    #     result_text = {'code': "200", 'data': res.decode('UTF-8')}
    #     print(result_text)
@@ -0,0 +1,35 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/22 14:44
# @Author:ZYP
# @File:fasttext_api.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import json

import numpy as np
from gensim.models.keyedvectors import KeyedVectors
import time
from flask import Flask, request

app_fasttext = Flask(__name__)

fasttext_path = '/home/zc-nlp-zyp/work_file/ssd_data/public_data/fasttext_model/fasttext.vector'
model_fasttext = KeyedVectors.load_word2vec_format(fasttext_path, binary=True)


@app_fasttext.route('/fasttext', methods=['POST'])
def get_word2vec():
    word_dict = json.loads(request.data.decode())
    try:
        vec = model_fasttext.get_vector(word_dict["word"])
        str_vec = ','.join([str(i) for i in vec])
        # vec1 = np.array([float(j) for j in str_vec.split(',')], dtype=np.float64)
        vec_dict = {'vec': str_vec}
        return json.dumps(vec_dict)
    except:
        return 'error_fasttext'


# if __name__ == '__main__':
#     app.run(host='0.0.0.0', port=50002, debug=False)
@@ -0,0 +1,25 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/22 15:30
# @Author:ZYP
# @File:fasttext_config.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import logging
import logging.handlers
import os
import gevent.monkey

gevent.monkey.patch_all()

bind = '0.0.0.0:50002'  # bound IP address and port
chdir = '/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/fasttext'  # working directory gunicorn switches to
timeout = 60  # timeout in seconds
worker_class = 'gevent'  # use the gevent worker; the default sync mode can also be used
workers = 4  # multiprocessing.cpu_count() * 2 + 1  # number of worker processes
threads = 4
loglevel = "info"  # log level; this sets the error-log level, the access-log level cannot be configured
access_log_format = '%(t)s %(p)s %(h)s "%(r)s" %(s)s %(L)s %(b)s %(f)s" "%(a)s"'  # gunicorn access-log format; the error-log format cannot be configured
pidfile = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/fasttext/fasttext_log/gunicorn.pid"
accesslog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/fasttext/fasttext_log/access.log"
errorlog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/fasttext/fasttext_log/error.log"
daemon = True
@@ -0,0 +1,27 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/21 14:36
# @Author:ZYP
# @File:flask_config.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import logging
import logging.handlers
import os
import multiprocessing
import gevent.monkey

gevent.monkey.patch_all()

bind = '0.0.0.0:50004'  # bound IP address and port
chdir = '/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo'  # working directory gunicorn switches to
timeout = 200  # timeout in seconds
worker_class = 'gevent'  # use the gevent worker; the default sync mode can also be used
workers = 5  # multiprocessing.cpu_count() * 2 + 1  # number of worker processes
threads = 1
loglevel = "info"  # log level; this sets the error-log level, the access-log level cannot be configured
access_log_format = '%(t)s %(p)s %(h)s "%(r)s" %(s)s %(L)s %(b)s %(f)s" "%(a)s"'  # gunicorn access-log format; the error-log format cannot be configured
pidfile = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/gunicornLogs/gunicorn.pid"
accesslog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/gunicornLogs/access.log"
errorlog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/gunicornLogs/error.log"

daemon = True
@@ -0,0 +1,102 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/21 18:41
# @Author:ZYP
# @File:mysql_collect.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import pymysql
from dbutils.pooled_db import PooledDB

host = '192.168.31.145'
port = 3306
user = 'root'
password = '123456'
database = 'zhiwang_db'


class MySQLConnectionPool:

    def __init__(self, ):
        self.pool = PooledDB(
            creator=pymysql,  # module used to create the connections
            mincached=20,  # number of idle connections created at start-up, 0 means none
            maxconnections=200,  # maximum number of connections in the pool, 0/None means unlimited
            blocking=True,  # whether to block when the pool is exhausted; True waits, False raises an error
            host=host,
            port=port,
            user=user,
            password=password,
            database=database
        )

    def open(self):
        conn = self.pool.connection()
        # self.cursor = self.conn.cursor(cursor=pymysql.cursors.DictCursor)  # would return rows as dictionaries
        cursor = conn.cursor()  # default cursor, returns rows as tuples
        return conn, cursor

    def close(self, cursor, conn):
        cursor.close()
        conn.close()

    def select_one(self, sql, *args):
        """Query a single row."""
        conn, cursor = self.open()
        cursor.execute(sql, args)
        result = cursor.fetchone()
        self.close(cursor, conn)
        return result

    def select_all(self, sql, args):
        """Query multiple rows."""
        conn, cursor = self.open()
        cursor.execute(sql, args)
        result = cursor.fetchall()
        self.close(cursor, conn)
        return result

    def insert_one(self, sql, args):
        """Insert a single row."""
        self.execute(sql, args, isNeed=True)

    def insert_all(self, sql, datas):
        """Insert multiple rows in one batch."""
        conn, cursor = self.open()
        try:
            cursor.executemany(sql, datas)
            conn.commit()
            return {'result': True, 'id': int(cursor.lastrowid)}
        except Exception as err:
            conn.rollback()
            return {'result': False, 'err': err}

    def update_one(self, sql, args):
        """Update rows."""
        self.execute(sql, args, isNeed=True)

    def delete_one(self, sql, *args):
        """Delete rows."""
        self.execute(sql, args, isNeed=True)

    def execute(self, sql, args, isNeed=False):
        """
        Execute a statement.
        :param isNeed: whether a rollback may be needed on failure
        """
        conn, cursor = self.open()
        if isNeed:
            try:
                cursor.execute(sql, args)
                conn.commit()
            except:
                conn.rollback()
        else:
            cursor.execute(sql, args)
            conn.commit()
        self.close(cursor, conn)


mysql = MySQLConnectionPool()
# sql_select_all = 'select * from `main_table_paper_detail_message` limit %s;'
# results = mysql.select_all(sql_select_all, (1,))
# print(results)
@@ -0,0 +1,141 @@
# -*- coding: utf-8 -*-
# @Time: 16:41
# @Author:ZYP
# @File:roformer_api.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm

# =========================================================================================
# Load the deep-learning models
# · load the paper classification model
# · load the BERT model
# =========================================================================================

import json
import os
import numpy as np
from bert4keras.models import build_transformer_model
from keras.layers import Lambda, Dense
from keras.models import Model
from bert4keras.tokenizers import Tokenizer
from bert4keras.backend import K
import tensorflow as tf
from keras.backend import set_session
from flask import Flask, request

# =========================================================================================================================
# parameters of the roformer model
# =========================================================================================================================
class_nums = 168
max_len = 512
roformer_config_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/chinese_roformer-v2-char_L-12_H-768_A-12/bert_config.json'
roformer_ckpt_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/chinese_roformer-v2-char_L-12_H-768_A-12/bert_model.ckpt'
roformer_vocab_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/chinese_roformer-v2-char_L-12_H-768_A-12/vocab.txt'
roformer_model_weights_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/model3/best_model.weights'
label_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/data/label_threshold.txt'

# must be set before the TensorFlow session is created, otherwise it has no effect
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True
set_session(tf.Session(config=tfconfig))  # note: this differs from the usual setup
global graph
graph = tf.get_default_graph()
sess = tf.Session(graph=graph)
set_session(sess)

app_roformer = Flask(__name__)


def load_roformer_model(config, ckpt):
    """Load the trained 168-class multi-label classification model."""
    roformer = build_transformer_model(
        config_path=config,
        checkpoint_path=ckpt,
        model='roformer_v2',
        return_keras_model=False)
    output = Lambda(lambda x: x[:, 0])(roformer.model.output)
    output = Dense(
        units=class_nums,
        kernel_initializer=roformer.initializer
    )(output)
    model1 = Model(roformer.model.input, output)
    model1.summary()
    return model1


def load_label(label_path1):
    """Load label2id, id2label and the per-class thresholds used for classification."""
    with open(label_path1, 'r', encoding='utf-8') as f:
        json_dict = json.load(f)
    label2id1 = {i: j[0] for i, j in json_dict.items()}
    id2label1 = {j[0]: i for i, j in json_dict.items()}
    label_threshold1 = np.array([j[1] for i, j in json_dict.items()])
    return label2id1, id2label1, label_threshold1


# load the label information
label2id, id2label, label_threshold = load_label(label_path)
# tokenizer of the roformer model
tokenizer_roformer = Tokenizer(roformer_vocab_path)

# build the model
model_roformer = load_roformer_model(roformer_config_path, roformer_ckpt_path)
set_session(sess)

# load the trained weights
model_roformer.load_weights(roformer_model_weights_path)


def encode(text_list1):
    """Encode the texts in the list one by one and concatenate them into a single model input."""
    sent_token_id1, sent_segment_id1 = [], []
    for index, text in enumerate(text_list1):
        if index == 0:
            token_id, segment_id = tokenizer_roformer.encode(text)
        else:
            token_id, segment_id = tokenizer_roformer.encode(text)
            token_id = token_id[1:]
            segment_id = segment_id[1:]
        if (index + 1) % 2 == 0:
            segment_id = [1] * len(token_id)
        sent_token_id1 += token_id
        sent_segment_id1 += segment_id
    if len(sent_token_id1) > max_len:
        sent_token_id1 = sent_token_id1[:max_len]
        sent_segment_id1 = sent_segment_id1[:max_len]

    sent_token_id = np.array([sent_token_id1])
    sent_segment_id = np.array([sent_segment_id1])
    return sent_token_id, sent_segment_id


@app_roformer.route('/roformer', methods=['POST'])
def pred_class_num():
    """Return the predicted class indexes (library ids); the submitted paper is expected as a dict with title, key_words, abst_zh, content, etc."""
    try:
        target_paper_dict = json.loads(request.data.decode())
        text_list1 = [target_paper_dict['title']]  # , target_paper_dict['key_words']
        abst_zh = target_paper_dict['abst_zh']

        if len(abst_zh.split("。")) <= 10:
            text_list1.append(abst_zh)
        else:
            text_list1.append("。".join(abst_zh.split('。')[:5]))
            text_list1.append("。".join(abst_zh.split('。')[-5:]))

        sent_token, segment_ids = encode(text_list1)

        with graph.as_default():
            K.set_session(sess)
            y_pred = model_roformer.predict([sent_token, segment_ids])

        idx = np.where(y_pred[0] > label_threshold, 1, 0)
        pred_label_num_dict = {'label_num': [index for index, i in enumerate(idx) if i == 1]}
        return json.dumps(pred_label_num_dict)
    except:
        return 'error_roformer'


# if __name__ == '__main__':
#     app_roformer.run('0.0.0.0', port=50003, debug=False)
@@ -0,0 +1,26 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/22 16:06
# @Author:ZYP
# @File:roformer_config.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import logging
import logging.handlers
import os
import multiprocessing
import gevent.monkey

gevent.monkey.patch_all()

bind = '0.0.0.0:50003'  # bound IP address and port
chdir = '/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/roformer'  # working directory gunicorn switches to
timeout = 60  # timeout in seconds
backlog = 2048
worker_class = 'gevent'  # use the gevent worker; the default sync mode can also be used
workers = 1  # multiprocessing.cpu_count() * 2 + 1  # number of worker processes
loglevel = "info"  # log level; this sets the error-log level, the access-log level cannot be configured
access_log_format = '%(t)s %(p)s %(h)s "%(r)s" %(s)s %(L)s %(b)s %(f)s" "%(a)s"'  # gunicorn access-log format; the error-log format cannot be configured
pidfile = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/roformer/roformer_log/gunicorn.pid"
accesslog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/roformer/roformer_log/access.log"
errorlog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/roformer/roformer_log/error.log"
daemon = True
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
# @Time: 18:02
# @Author:ZYP
# @File:util.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
# =========================================================================================
# Utility module
# Loads the stop words, the database helpers and the word2vec / fasttext models
# =========================================================================================
import os
import time
import math
import json
import jieba
import numpy as np
import requests
from collections import defaultdict
from textrank4zh import TextRank4Keyword

jieba.initialize()

stop_word_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/fasttext_train/data/total_stopwords.txt'


class Logging:
    def __init__(self):
        pass

    def log(self, *args, **kwargs):
        format = '%Y/%m/%d-%H:%M:%S'
        format_h = '%Y-%m-%d'
        value = time.localtime(int(time.time()))
        dt = time.strftime(format, value)
        dt_log_file = time.strftime(format_h, value)
        log_file = 'gunicornLogs/access-%s-%s' % (str(os.getpid()), dt_log_file) + ".log"
        if not os.path.exists(log_file):
            with open(os.path.join(log_file), 'w', encoding='utf-8') as f:
                print(dt, *args, file=f, **kwargs)
        else:
            with open(os.path.join(log_file), 'a+', encoding='utf-8') as f:
                print(dt, *args, file=f, **kwargs)


def load_stopwords(path=stop_word_path):
    """Load the stop words."""
    with open(path, 'r', encoding='utf-8') as f:
        stop_words = {i.strip() for i in f.readlines()}
    return stop_words


def cut_text(text_str, tokenizer='jieba'):
    """Tokenize the text with the chosen tokenizer, count the frequency of every word and return the dict sorted in descending order."""
    word_dict = defaultdict(int)
    if tokenizer == 'jieba':
        all_word_list = jieba.cut(text_str)
        for word in all_word_list:
            if word not in stop_word:
                word_dict[word] += 1
    # elif tokenizer == 'hanlp':
    #     for i in HanLP.segment(text_str):
    #         if i.word not in stop_word and i.nature != 'w':
    #             word_dict[i.word] += 1
    else:
        print('Invalid tokenizer argument!')

    return {k: v for k, v in sorted(word_dict.items(), key=lambda x: x[1], reverse=True)}


def l2_normal(tf_idf_dict):
    """L2-normalize the tf-idf dict and keep the 15 highest-weighted words."""
    l2_norm = math.sqrt(sum(map(lambda x: x ** 2, tf_idf_dict.values())))
    tf_idf_dict1 = sorted(tf_idf_dict.items(), key=lambda x: x[1] / l2_norm, reverse=True)
    tf_idf_dict2 = {key: value / l2_norm for key, value in tf_idf_dict1[:15]}
    return tf_idf_dict2
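
# Illustrative sketch (not part of the original file): l2_normal on a made-up tf-idf dict.
# The L2 norm of {a: 3.0, b: 4.0} is 5.0, so the returned weights are 0.8 and 0.6.
def _demo_l2_normal():
    return l2_normal({'a': 3.0, 'b': 4.0})  # {'b': 0.8, 'a': 0.6}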


def get_word_vec(word):
    """Fetch the vector of a word from the word2vec service, falling back to the fasttext service; return 0 if the word is in neither vocabulary."""
    vec = requests.post('http://192.168.31.74:50001/word2vec', data=json.dumps({'word': word}), timeout=100)
    if len(vec.text) < 100:
        vec = requests.post('http://192.168.31.74:50002/fasttext', data=json.dumps({'word': word}), timeout=100)
        if len(vec.text) < 100:
            vec = 0
            return vec
        else:
            json_dict = json.loads(vec.text)
            res_vec = np.array([float(j) for j in json_dict["vec"].split(',')], dtype=np.float64)
            return res_vec
    else:
        json_dict = json.loads(vec.text)
        res_vec = np.array([float(j) for j in json_dict["vec"].split(',')], dtype=np.float64)
        return res_vec


stop_word = load_stopwords()
tr4w = TextRank4Keyword(stop_words_file=stop_word_path)
@@ -0,0 +1,34 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/22 14:44
# @Author:ZYP
# @File:word2vec_api.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import json

import numpy as np
from gensim.models.keyedvectors import KeyedVectors
import time
from flask import Flask, request

app_word2vec = Flask(__name__)

word2vec_path = "/home/zc-nlp-zyp/work_file/ssd_data/public_data/word2vec_model/word2vec.vector"
model_word2vec = KeyedVectors.load_word2vec_format(word2vec_path, binary=True)


@app_word2vec.route('/word2vec', methods=['POST'])
def get_word2vec():
    word_dict = json.loads(request.data.decode())
    try:
        vec = model_word2vec.get_vector(word_dict["word"])
        str_vec = ','.join([str(i) for i in vec])
        # vec1 = np.array([float(j) for j in str_vec.split(',')], dtype=np.float64)
        vec_dict = {'vec': str_vec}
        return json.dumps(vec_dict)
    except:
        return 'error_word2vec'


# if __name__ == '__main__':
#     app.run(host='0.0.0.0', port=50001, debug=False)
@@ -0,0 +1,26 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/22 15:30
# @Author:ZYP
# @File:word2vec_config.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import logging
import logging.handlers
import os
import multiprocessing
import gevent.monkey

gevent.monkey.patch_all()

bind = '0.0.0.0:50001'  # bound IP address and port
chdir = '/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/word2vec'  # working directory gunicorn switches to
timeout = 60  # timeout in seconds
worker_class = 'gevent'  # use the gevent worker; the default sync mode can also be used
workers = 4  # multiprocessing.cpu_count() * 2 + 1  # number of worker processes
threads = 4
loglevel = "info"  # log level; this sets the error-log level, the access-log level cannot be configured
access_log_format = '%(t)s %(p)s %(h)s "%(r)s" %(s)s %(L)s %(b)s %(f)s" "%(a)s"'  # gunicorn access-log format; the error-log format cannot be configured
pidfile = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/word2vec/word2vec_log/gunicorn.pid"
accesslog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/word2vec/word2vec_log/access.log"
errorlog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/word2vec/word2vec_log/error.log"
daemon = True