
20230831_查重 (plagiarism check)

master
zhangyupeng, 2 years ago
commit ded2da121b
  1. check_version_1_0/CheckFalsk.py (+52 lines)
  2. check_version_1_0/SearchSimPaper.py (+254 lines)
  3. check_version_1_0/demo01_multiprocessimg.py (+85 lines)
  4. check_version_1_0/demo01_test_redis.py (+110 lines)
  5. check_version_1_0/fasttext/fasttext_api.py (+35 lines)
  6. check_version_1_0/fasttext/fasttext_config.py (+25 lines)
  7. check_version_1_0/flask_config.py (+27 lines)
  8. check_version_1_0/mysql_collect.py (+102 lines)
  9. check_version_1_0/roformer/roformer_api.py (+141 lines)
  10. check_version_1_0/roformer/roformer_config.py (+26 lines)
  11. check_version_1_0/util.py (+96 lines)
  12. check_version_1_0/word2vec/word2vec_api.py (+34 lines)
  13. check_version_1_0/word2vec/word2vec_config.py (+26 lines)

check_version_1_0/CheckFalsk.py (+52 lines)

File diff suppressed because one or more lines are too long

check_version_1_0/SearchSimPaper.py (+254 lines)

@@ -0,0 +1,254 @@
# -*- coding = utf-8 -*-
# @Time: 18:01
# @Author:ZYP
# @File:SearchSimPaper.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import gc
import time
# =========================================================================================
# Find similar documents
# · intersect the keywords of the submitted document with the inverted index
# · compute keyword-based similarity between the selected candidates and the submitted document
# · sort the most similar documents and return them
# =========================================================================================
import math
import numpy as np
from collections import defaultdict
from pymysql.converters import escape_string
from sklearn.metrics.pairwise import cosine_similarity
from util import cut_text, l2_normal, get_word_vec


def load_inverted_table(class_list, mysql, log):
    """Merge the inverted index of every predicted class and return the merged dict plus the total number of papers in those classes."""
    # merged inverted index {word: "doc_id1,doc_id2,doc_id3,..."}
    total_inverted_dict1 = {}
    # total number of papers across the classes
    total_nums1 = 0
    for label_num in class_list:
        conn, cursor = mysql.open()
        select_sql = """
            select word, paper_doc_id from word_map_paper_{};
        """.format(str(label_num))
        s_time1 = time.time()
        cursor.execute(select_sql)
        for word, paper_doc_id in cursor.fetchall():
            if word not in total_inverted_dict1.keys():
                total_inverted_dict1[word] = paper_doc_id
            else:
                # total_inverted_dict1[word] = ','.join(
                #     set(total_inverted_dict1[word].split(',') + paper_doc_id.split(',')))
                total_inverted_dict1[word] = total_inverted_dict1[word] + ',' + paper_doc_id
        e_time1 = time.time()
        log.log('Time spent loading the inverted index of class {}: {}s'.format(str(label_num), e_time1 - s_time1))
        s_time2 = time.time()
        select_paper_num_sql = """
            select count_number from count_map where label_num={};
        """.format(str(label_num))
        cursor.execute(select_paper_num_sql)
        for nums in cursor.fetchall():
            total_nums1 += int(nums[0])
        e_time2 = time.time()
        log.log('Time spent counting the papers of class {}: {}s'.format(str(label_num), e_time2 - s_time2))
        mysql.close(cursor, conn)
    return total_inverted_dict1, total_nums1


def select_sim_doc_message(sim_doc1, mysql):
    """
    Look up each similar doc_id in the database and compute its mean document vector, returned as a dict {doc_id: mean vector, ...}.
    :param sim_doc1: list of similar document ids [doc_id1, doc_id2, ...]
    :return: dict {doc_id: (doc_avg_vec, doc_path)}
    """
    all_paper_vec_dict = {}
    conn, cursor = mysql.open()
    for doc_id in sim_doc1:
        select_sql = """
            select tb1.doc_id, tb1.title, tb1.abst_zh, tb2.vsm, tb1.content from
            (
                (select doc_id, title, abst_zh, content from main_table_paper_detail_message) tb1
                left join
                (select doc_id, vsm from id_keywords_weights) tb2
                on
                tb1.doc_id=tb2.doc_id
            ) where tb1.doc_id="{}";
        """.format(escape_string(doc_id))
        cursor.execute(select_sql)
        sim_doc_id, sim_title, sim_abst, sim_vsm, sim_content_path = cursor.fetchone()
        sim_vsm_dict = {weight.split('@#$@')[0]: float(weight.split('@#$@')[1]) for weight in sim_vsm.split('&*^%')}
        vector_paper = []
        value_sum = 0.0
        for word, weight in sim_vsm_dict.items():
            if word in sim_title:
                value = 0.5 * weight
            elif word in sim_abst:
                value = 0.3 * weight
            else:
                value = 0.2 * weight
            word_vec = get_word_vec(word)
            if isinstance(word_vec, int):
                continue
            vector_paper.append(word_vec * value)
            value_sum += value
        del sim_vsm_dict
        gc.collect()
        # weighted mean of the keyword vectors of one document
        # avg_vector = np.array(np.sum(np.array(vector_paper, dtype=np.float32), axis=0) / len(vector_paper))
        avg_vector = np.array(np.sum(np.array(vector_paper, dtype=np.float32), axis=0) / value_sum)
        all_paper_vec_dict[doc_id] = (avg_vector, sim_content_path)
    mysql.close(cursor, conn)
    return all_paper_vec_dict


def submit_paper_avg_vec(paper_dict1, tf_weight_dict):
    """Compute the mean vector of the submitted document from its tf weights and return it as a numpy array."""
    vector_paper = []
    value_sum = 0.0
    for word, weight in tf_weight_dict.items():
        if word in paper_dict1['title']:
            value = 0.5 * weight
        elif word in paper_dict1['abst_zh']:
            value = 0.3 * weight
        else:
            value = 0.2 * weight
        word_vec = get_word_vec(word)
        if isinstance(word_vec, int):
            continue
        vector_paper.append(word_vec * value)
        value_sum += value
    # avg_vector = np.array(np.sum(np.array(vector_paper, dtype=np.float32), axis=0) / len(vector_paper))
    avg_vector = np.array(np.sum(np.array(vector_paper, dtype=np.float32), axis=0) / value_sum)
    return avg_vector


def compare_sim_in_papers(check_vector, sim_message, top=40):
    """
    Compute the similarity between documents using cosine similarity.
    :param check_vector: text vector of the submitted paper
    :param sim_message: candidate similar documents, stored as a dict
    :param top: number of most similar documents to return
    :return: dict of similar documents {doc_id: (similarity score, document path)}
    """
    sim_res_dict = {}
    for doc_id, (vector, content_path) in sim_message.items():
        # sim_res_dict[doc_id] = cosine_similarity([scalar(check_vector), scalar(vector)])[0][1]
        sim_res_dict[doc_id] = (str(cosine_similarity([check_vector, vector])[0][1]), content_path)
    _ = sorted(sim_res_dict.items(), key=lambda x: float(x[1][0]), reverse=True)
    return {key: value for key, value in _[:top]}


def search_sim_paper(paper_dict, class_list, mysql, log, top=100):
    """
    Query the library for documents similar to the submitted paper and return the top matches for sentence-level checking.
    :param paper_dict: the formatted submitted paper
    :param class_list: list of class ids predicted for the submitted paper
    :param top: number of documents to return
    :return: dict of similar documents {doc_id: (similarity score, document path)}
    """
    all_str = paper_dict['title'] + '' + paper_dict['abst_zh'] + '' + paper_dict['content']
    # merge the inverted indexes and count the total number of papers; total_inverted_dict is the merged index
    s0 = time.time()
    total_inverted_dict, total_nums = load_inverted_table(class_list, mysql, log)
    e0 = time.time()
    log.log('Time spent querying the inverted index: {}s'.format(e0 - s0))
    s1 = time.time()
    # word-frequency dict of the submitted document {word1: fre1, word2: fre2, ...}
    word_fre_dict = cut_text(all_str, tokenizer='jieba')
    e1 = time.time()
    log.log('Tokenization time: {}s'.format(e1 - s1))
    s2 = time.time()
    # compute the tf-idf value of every word of the submitted document
    tf_idf_dict = {}
    for word, freq in word_fre_dict.items():
        if freq <= 2:
            continue
        tf = freq / sum(word_fre_dict.values())
        if word in total_inverted_dict.keys():
            idf = math.log(total_nums / (len(set(total_inverted_dict[word].split(','))) + 1))
        else:
            idf = math.log(total_nums / 1)
        tf_idf = tf * idf
        tf_idf_dict[word] = tf_idf
    e2 = time.time()
    log.log('Time spent computing the TF-IDF values of the submitted document: {}s'.format(e2 - s2))
    s3 = time.time()
    # top-15 words and their weights
    tf_dict = l2_normal(tf_idf_dict)
    e3 = time.time()
    log.log('Time spent normalizing the weights: {}s'.format(e3 - s3))
    s4 = time.time()
    # count how many of the submitted document's keywords each candidate doc_id shares
    count_words_num = defaultdict(int)
    for word, weight in tf_dict.items():
        if word in total_inverted_dict.keys():
            for doc_id in set(total_inverted_dict[word].split(',')):
                count_words_num[doc_id] += 1
        else:
            continue
    e4 = time.time()
    log.log('Time spent counting the doc_id intersection: {}s'.format(e4 - s4))
    # sort by the number of shared keywords
    count_word_num = {i: j for i, j in sorted(count_words_num.items(), key=lambda x: x[1], reverse=True)}
    # keep the 200 most similar candidate documents
    sim_doc = list(count_word_num.keys())[:200]
    # compute the mean document vector of these 200 documents
    s_time1 = time.time()
    sim_paper_vec_dict = select_sim_doc_message(sim_doc, mysql)
    e_time1 = time.time()
    log.log('Time spent computing the 200 mean vectors: {}s'.format(e_time1 - s_time1))
    # compute the mean document vector of the submitted document
    s_time2 = time.time()
    submit_vec = submit_paper_avg_vec(paper_dict, tf_dict)
    e_time2 = time.time()
    log.log('Time spent computing the mean vector of the submitted document: {}s'.format(e_time2 - s_time2))
    # compute the similarity between the submitted document and the candidates, sort, and keep the top papers for full-text checking
    s_time3 = time.time()
    sim_paper_dict = compare_sim_in_papers(submit_vec, sim_paper_vec_dict, top=top)
    e_time3 = time.time()
    log.log('Time spent computing (and sorting) the similarity between the submitted document and the candidates: {}s'.format(e_time3 - s_time3))
    del total_inverted_dict
    del total_nums
    del submit_vec
    del sim_paper_vec_dict
    del count_word_num
    del sim_doc
    del word_fre_dict
    gc.collect()
    return sim_paper_dict
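For orientation, a minimal driver sketch for search_sim_paper outside the Redis worker, assuming the pool object from mysql_collect.py and the Logging class from util.py are importable; the paper fields (title, abst_zh, content) are the ones the function reads, and the sample values are placeholders (the class ids mirror the commented example in demo01_multiprocessimg.py).

# Hypothetical driver script, not part of this commit.
from util import Logging
from mysql_collect import mysql
from SearchSimPaper import search_sim_paper

log = Logging()
paper_dict = {
    'title': '一种基于倒排表的论文查重方法',                        # placeholder title
    'abst_zh': '本文给出一种基于TF-IDF与词向量的相似论文检索方法。',  # placeholder abstract
    'content': '论文正文文本……',                                    # placeholder body text
}
class_list = [117, 36, 81]   # class ids as returned by the /roformer service (see demo01_multiprocessimg.py)
top_sim = search_sim_paper(paper_dict, class_list, mysql, log, top=100)
for doc_id, (score, path) in top_sim.items():
    print(doc_id, score, path)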

check_version_1_0/demo01_multiprocessimg.py (+85 lines)

@@ -0,0 +1,85 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/29 18:58
# @Author:ZYP
# @File:demo01_multiprocessimg.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import json
import multiprocessing
import os
import time
import uuid
import signal
from util import Logging
import redis
import requests
from SearchSimPaper import search_sim_paper
from mysql_collect import mysql

pool = redis.ConnectionPool(host='192.168.31.145', port=63179, max_connections=50, password='zhicheng123*', db=8)
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
db_key_query = 'query'
log = Logging()


def check_main_func(uuid_):
    while True:
        if redis_.llen(db_key_query) == 0:
            continue
        while True:
            ss = redis_.rpop(db_key_query)
            if ss is None:
                time.sleep(2)
            else:
                break
        paper_data = json.loads(ss.decode())
        id_ = paper_data["id"]
        message_dict = paper_data["paper"]
        class_res = \
            json.loads(requests.post('http://192.168.31.145:50003/roformer', data=json.dumps(message_dict)).text)[
                'label_num']
        # class_res = [117, 36, 81]
        sim_paper_id_dict = search_sim_paper(message_dict, class_res, mysql, log)
        redis_.set(id_, json.dumps(sim_paper_id_dict))
        pid = redis_.hget('process_pid', uuid_)
        log.log("uuid of this worker process:", uuid_)
        redis_.hdel("process_pid", uuid_)
        # handle one job, then terminate this worker so the main loop can spawn a fresh one
        os.kill(int(pid), signal.SIGTERM)
        break


# def set_process():
#     name = str(uuid.uuid1())
#     process = multiprocessing.Process(target=check_main_func, args=(), name=name)
#     process.start()
#     process.join()
#     return name, process


if __name__ == '__main__':
    # while redis_.llen('process_num') < 4:
    #     name, process = set_process()
    #     redis_.lpush('process_num', name)
    #     process.is_alive()
    # pool = multiprocessing.Pool(processes=4)
    # while True:
    #     if redis_.llen('process_num') < 4:
    #         redis_.lpush('process_num', '1')
    #         pool.apply_async(check_main_func, args=())
    # pool = multiprocessing.Pool(processes=4)
    while True:
        # keep at most 4 worker processes registered in the 'process_pid' hash
        if redis_.hlen('process_pid') < 4:
            uuid_ = str(uuid.uuid1())
            process = multiprocessing.Process(target=check_main_func, args=(uuid_,))
            process.start()
            process_id = process.pid
            redis_.hset('process_pid', uuid_, str(process_id))
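The worker above pops jobs from the Redis list 'query' and writes each result back under the job's id. A producer-side sketch of that round trip, assuming the same Redis host, port, password and db hard-coded above; the paper payload is a placeholder.

# Hypothetical producer for the 'query' list consumed by check_main_func.
import json
import time
import uuid
import redis

r = redis.Redis(host='192.168.31.145', port=63179, password='zhicheng123*', db=8)
job_id = str(uuid.uuid1())
job = {'id': job_id, 'paper': {'title': '示例标题', 'abst_zh': '示例摘要。', 'content': '示例正文。'}}
r.rpush('query', json.dumps(job))      # enqueue the job
while True:                            # poll until a worker stores the result under the job id
    result = r.get(job_id)
    if result is not None:
        print(json.loads(result))      # {doc_id: [score, path], ...}
        break
    time.sleep(1)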

check_version_1_0/demo01_test_redis.py (+110 lines)

@@ -0,0 +1,110 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/28 15:03
# @Author:ZYP
# @File:demo01_test_redis.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import time
import flask
import redis
import uuid
import json
import requests
# from util import Logging
# from threading import Thread
# from mysql_collect import mysql
# from SearchSimPaper import search_sim_paper
# import jieba
from flask import request
# import multiprocessing

app_check = flask.Flask(__name__)
pool = redis.ConnectionPool(host='192.168.31.145', port=63179, max_connections=50, password='zhicheng123*', db=8)
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
# pool1 = redis.ConnectionPool(host='192.168.31.145', port=63179, max_connections=50, password='zhicheng123*', db=11)
# redis_1 = redis.Redis(connection_pool=pool1, decode_responses=True)
# jieba.initialize()
db_key_query = 'query'
# db_key_result = 'result'
# log = Logging()


# def check_main_func():
#     while True:
#         if redis_.llen(db_key_query) == 0:
#             continue
#         while True:
#             ss = redis_.rpop(db_key_query)
#             if ss is None:
#                 time.sleep(2)
#             else:
#                 break
#
#         paper_data = json.loads(ss.decode())
#         id_ = paper_data["id"]
#         message_dict = paper_data["paper"]
#         class_res = \
#             json.loads(requests.post('http://192.168.31.145:50003/roformer', data=json.dumps(message_dict)).text)[
#                 'label_num']
#
#         sim_paper_id_dict = search_sim_paper(message_dict, class_res, mysql, log)
#         redis_.set(id_, json.dumps(sim_paper_id_dict))


@app_check.route("/check", methods=["POST"])
def handle_query():
    s = time.time()
    message_dict = json.loads(request.data.decode())
    uuid_request = str(message_dict['uuid'])
    id_ = str(uuid.uuid1())  # generate a unique id for this query
    d = {'id': id_, 'paper': message_dict}  # bind the text to the query id
    redis_.rpush(db_key_query, json.dumps(d))
    while True:
        result = redis_.get(id_)
        if result is not None:
            redis_.delete(id_)
            result_text = {'uuid': uuid_request, 'data': result.decode('UTF-8')}
            # result_text = json.loads(result.decode('UTF-8'))
            break
    e = time.time()
    print('{} took {} s'.format(uuid_request, (e - s)))
    redis_.lpush('query_recall', json.dumps(result_text))
    return uuid_request  # return the result
    # # return '1'


if __name__ == "__main__":
    # for i in range(2):
    #     t = Thread(target=check_main_func, args=())
    #     t.start()
    # processes = []
    #
    # # create and start multiple worker processes
    # for i in range(2):
    #     process = multiprocessing.Process(target=check_main_func, args=())
    #     processes.append(process)
    #     process.start()
    app_check.run(debug=False, host='0.0.0.0', port=50004)
    # res = redis_.rpop(db_key_query)
    # print(res)
    # id_ = "51bc72dc-464e-11ee-baf3-45147420c4fb"
    # res = redis_.get(id_)
    # if res is not None:
    #     redis_.delete(id_)
    #     result_text = {'code': "200", 'data': res.decode('UTF-8')}
    #     print(result_text)
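A hedged client example for the /check endpoint: the uuid field is what handle_query reads, the remaining fields are the ones the downstream roformer and SearchSimPaper code expects, and the host is an assumption (the service binds 0.0.0.0:50004).

# Hypothetical client for the /check endpoint.
import json
import requests

payload = {
    'uuid': 'demo-0001',    # request id echoed back by the service
    'title': '示例标题',
    'abst_zh': '示例摘要。',
    'content': '示例正文。',
}
resp = requests.post('http://127.0.0.1:50004/check', data=json.dumps(payload), timeout=300)
print(resp.text)            # the uuid is returned; the result itself is pushed onto the Redis list 'query_recall'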

check_version_1_0/fasttext/fasttext_api.py (+35 lines)

@@ -0,0 +1,35 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/22 14:44
# @Author:ZYP
# @File:fasttext_api.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import json
import numpy as np
from gensim.models.keyedvectors import KeyedVectors
import time
from flask import Flask, request

app_fasttext = Flask(__name__)
fasttext_path = '/home/zc-nlp-zyp/work_file/ssd_data/public_data/fasttext_model/fasttext.vector'
model_fasttext = KeyedVectors.load_word2vec_format(fasttext_path, binary=True)


@app_fasttext.route('/fasttext', methods=['POST'])
def get_word2vec():
    word_dict = json.loads(request.data.decode())
    try:
        vec = model_fasttext.get_vector(word_dict["word"])
        str_vec = ','.join([str(i) for i in vec])
        # vec1 = np.array([float(j) for j in str_vec.split(',')], dtype=np.float64)
        vec_dict = {'vec': str_vec}
        return json.dumps(vec_dict)
    except:
        # the word is not in the fastText vocabulary
        return 'error_fasttext'

# if __name__ == '__main__':
#     app.run(host='0.0.0.0', port=50002, debug=False)
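For reference, this endpoint is what util.get_word_vec falls back to; a standalone sketch of the same call, with the host assumed to be local and the port taken from fasttext_config.py.

# Hypothetical client for the /fasttext word-vector endpoint.
import json
import numpy as np
import requests

resp = requests.post('http://127.0.0.1:50002/fasttext', data=json.dumps({'word': '论文'}), timeout=100)
if resp.text == 'error_fasttext':   # the word is missing from the vocabulary
    vec = None
else:
    vec = np.array([float(x) for x in json.loads(resp.text)['vec'].split(',')], dtype=np.float64)
print(vec)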

check_version_1_0/fasttext/fasttext_config.py (+25 lines)

@@ -0,0 +1,25 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/22 15:30
# @Author:ZYP
# @File:fasttext_config.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import logging
import logging.handlers
import os
import gevent.monkey

gevent.monkey.patch_all()

bind = '0.0.0.0:50002'  # ip and port to bind
chdir = '/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/fasttext'  # working directory gunicorn switches to
timeout = 60  # timeout in seconds
worker_class = 'gevent'  # use the gevent worker class; sync is also possible and is the default
workers = 4  # multiprocessing.cpu_count() * 2 + 1  # number of worker processes
threads = 4
loglevel = "info"  # log level; it applies to the error log, the access-log level cannot be set
access_log_format = '%(t)s %(p)s %(h)s "%(r)s" %(s)s %(L)s %(b)s %(f)s" "%(a)s"'  # gunicorn access-log format; the error-log format cannot be configured here
pidfile = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/fasttext/fasttext_log/gunicorn.pid"
accesslog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/fasttext/fasttext_log/access.log"
errorlog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/fasttext/fasttext_log/error.log"
daemon = True

check_version_1_0/flask_config.py (+27 lines)

@@ -0,0 +1,27 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/21 14:36
# @Author:ZYP
# @File:flask_config.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import logging
import logging.handlers
import os
import multiprocessing
import gevent.monkey

gevent.monkey.patch_all()

bind = '0.0.0.0:50004'  # ip and port to bind
chdir = '/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo'  # working directory gunicorn switches to
timeout = 200  # timeout in seconds
worker_class = 'gevent'  # use the gevent worker class; sync is also possible and is the default
workers = 5  # multiprocessing.cpu_count() * 2 + 1  # number of worker processes
threads = 1
loglevel = "info"  # log level; it applies to the error log, the access-log level cannot be set
access_log_format = '%(t)s %(p)s %(h)s "%(r)s" %(s)s %(L)s %(b)s %(f)s" "%(a)s"'  # gunicorn access-log format; the error-log format cannot be configured here
pidfile = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/gunicornLogs/gunicorn.pid"
accesslog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/gunicornLogs/access.log"
errorlog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/gunicornLogs/error.log"
daemon = True

check_version_1_0/mysql_collect.py (+102 lines)

@@ -0,0 +1,102 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/21 18:41
# @Author:ZYP
# @File:mysql_collect.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import pymysql
from dbutils.pooled_db import PooledDB

host = '192.168.31.145'
port = 3306
user = 'root'
password = '123456'
database = 'zhiwang_db'


class MySQLConnectionPool:

    def __init__(self, ):
        self.pool = PooledDB(
            creator=pymysql,      # module used to create the connections
            mincached=20,         # connections created at start-up; 0 means none
            maxconnections=200,   # maximum connections allowed by the pool; 0/None means unlimited
            blocking=True,        # block and wait when no connection is free (True) instead of raising an error (False)
            host=host,
            port=port,
            user=user,
            password=password,
            database=database
        )

    def open(self):
        conn = self.pool.connection()
        # self.cursor = self.conn.cursor(cursor=pymysql.cursors.DictCursor)  # would return rows as dictionaries
        cursor = conn.cursor()  # default cursor, rows are returned as tuples
        return conn, cursor

    def close(self, cursor, conn):
        cursor.close()
        conn.close()

    def select_one(self, sql, *args):
        """Query a single row."""
        conn, cursor = self.open()
        cursor.execute(sql, args)
        result = cursor.fetchone()
        self.close(cursor, conn)
        return result

    def select_all(self, sql, args):
        """Query multiple rows."""
        conn, cursor = self.open()
        cursor.execute(sql, args)
        result = cursor.fetchall()
        self.close(cursor, conn)
        return result

    def insert_one(self, sql, args):
        """Insert a single row."""
        self.execute(sql, args, isNeed=True)

    def insert_all(self, sql, datas):
        """Insert multiple rows in one batch."""
        conn, cursor = self.open()
        try:
            cursor.executemany(sql, datas)
            conn.commit()
            return {'result': True, 'id': int(cursor.lastrowid)}
        except Exception as err:
            conn.rollback()
            return {'result': False, 'err': err}

    def update_one(self, sql, args):
        """Update rows."""
        self.execute(sql, args, isNeed=True)

    def delete_one(self, sql, *args):
        """Delete rows."""
        self.execute(sql, args, isNeed=True)

    def execute(self, sql, args, isNeed=False):
        """
        Execute a statement.
        :param isNeed: whether rollback handling is needed
        """
        conn, cursor = self.open()
        if isNeed:
            try:
                cursor.execute(sql, args)
                conn.commit()
            except:
                conn.rollback()
        else:
            cursor.execute(sql, args)
            conn.commit()
        self.close(cursor, conn)


mysql = MySQLConnectionPool()

# sql_select_all = 'select * from `main_table_paper_detail_message` limit %s;'
# results = mysql.select_all(sql_select_all, (1,))
# print(results)
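A short usage sketch of the pool, showing both the wrapped select_all helper (as in the commented example above) and the raw open/execute/close pattern used in SearchSimPaper.py; the table names and the label id are ones that already appear in this commit.

# Hypothetical usage of MySQLConnectionPool.
from mysql_collect import mysql

# wrapped helper
rows = mysql.select_all('select * from `main_table_paper_detail_message` limit %s;', (3,))
print(rows)

# raw open / execute / close pattern, as used in SearchSimPaper.load_inverted_table
conn, cursor = mysql.open()
cursor.execute('select count_number from count_map where label_num=%s;', (117,))
print(cursor.fetchall())
mysql.close(cursor, conn)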

check_version_1_0/roformer/roformer_api.py (+141 lines)

@@ -0,0 +1,141 @@
# -*- coding = utf-8 -*-
# @Time: 16:41
# @Author:ZYP
# @File:roformer_api.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
# =========================================================================================
# Load the deep-learning models
# · the paper-classification model
# · the BERT (RoFormer) backbone
# =========================================================================================
import json
import os
import numpy as np
from bert4keras.models import build_transformer_model
from keras.layers import Lambda, Dense
from keras.models import Model
from bert4keras.tokenizers import Tokenizer
from bert4keras.backend import K
import tensorflow as tf
from keras.backend import set_session
from flask import Flask, request

# =========================================================================================================================
# roformer model parameters
# =========================================================================================================================
class_nums = 168
max_len = 512
roformer_config_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/chinese_roformer-v2-char_L-12_H-768_A-12/bert_config.json'
roformer_ckpt_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/chinese_roformer-v2-char_L-12_H-768_A-12/bert_model.ckpt'
roformer_vocab_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/chinese_roformer-v2-char_L-12_H-768_A-12/vocab.txt'
roformer_model_weights_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/model3/best_model.weights'
label_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/data/label_threshold.txt'

tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True
set_session(tf.Session(config=tfconfig))  # note: this part differs
global graph
graph = tf.get_default_graph()
sess = tf.Session(graph=graph)
set_session(sess)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

app_roformer = Flask(__name__)


def load_roformer_model(config, ckpt):
    """Load the trained 168-class multi-label classification model."""
    roformer = build_transformer_model(
        config_path=config,
        checkpoint_path=ckpt,
        model='roformer_v2',
        return_keras_model=False)
    output = Lambda(lambda x: x[:, 0])(roformer.model.output)
    output = Dense(
        units=class_nums,
        kernel_initializer=roformer.initializer
    )(output)
    model1 = Model(roformer.model.input, output)
    model1.summary()
    return model1


def load_label(label_path1):
    """Load label2id, id2label and the per-class thresholds used for classification."""
    with open(label_path1, 'r', encoding='utf-8') as f:
        json_dict = json.load(f)
    label2id1 = {i: j[0] for i, j in json_dict.items()}
    id2label1 = {j[0]: i for i, j in json_dict.items()}
    label_threshold1 = np.array([j[1] for i, j in json_dict.items()])
    return label2id1, id2label1, label_threshold1


# load the label information
label2id, id2label, label_threshold = load_label(label_path)
# tokenizer of the roformer model
tokenizer_roformer = Tokenizer(roformer_vocab_path)
# build the model
model_roformer = load_roformer_model(roformer_config_path, roformer_ckpt_path)
set_session(sess)
# load the trained weights
model_roformer.load_weights(roformer_model_weights_path)


def encode(text_list1):
    """Encode the texts in the list one after another and concatenate their token ids."""
    sent_token_id1, sent_segment_id1 = [], []
    for index, text in enumerate(text_list1):
        if index == 0:
            token_id, segment_id = tokenizer_roformer.encode(text)
        else:
            token_id, segment_id = tokenizer_roformer.encode(text)
            token_id = token_id[1:]
            segment_id = segment_id[1:]
        if (index + 1) % 2 == 0:
            segment_id = [1] * len(token_id)
        sent_token_id1 += token_id
        sent_segment_id1 += segment_id
    if len(sent_token_id1) > max_len:
        sent_token_id1 = sent_token_id1[:max_len]
        sent_segment_id1 = sent_segment_id1[:max_len]
    sent_token_id = np.array([sent_token_id1])
    sent_segment_id = np.array([sent_segment_id1])
    return sent_token_id, sent_segment_id


@app_roformer.route('/roformer', methods=['POST'])
def pred_class_num():
    """Return the indexes of the predicted classes (i.e. which class libraries to search); the submitted paper is expected as a dict with title, key_words, abst_zh, content, etc."""
    try:
        target_paper_dict = json.loads(request.data.decode())
        text_list1 = [target_paper_dict['title']]  # , target_paper_dict['key_words']
        abst_zh = target_paper_dict['abst_zh']
        if len(abst_zh.split("。")) <= 10:
            text_list1.append(abst_zh)
        else:
            # keep only the first five and the last five sentences of a long abstract
            text_list1.append("。".join(abst_zh.split('。')[:5]))
            text_list1.append("。".join(abst_zh.split('。')[-5:]))
        sent_token, segment_ids = encode(text_list1)
        with graph.as_default():
            K.set_session(sess)
            y_pred = model_roformer.predict([sent_token, segment_ids])
        idx = np.where(y_pred[0] > label_threshold, 1, 0)
        pred_label_num_dict = {'label_num': [index for index, i in enumerate(idx) if i == 1]}
        return json.dumps(pred_label_num_dict)
    except:
        return 'error_roformer'

# if __name__ == '__main__':
#     app_roformer.run('0.0.0.0', port=50003, debug=False)
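A hedged client example for the classifier service: the payload carries the title and abst_zh fields pred_class_num reads (key_words is commented out above), the host is assumed local, and the port comes from roformer_config.py.

# Hypothetical client for the /roformer classification endpoint.
import json
import requests

paper = {'title': '示例标题', 'abst_zh': '示例摘要。'}
resp = requests.post('http://127.0.0.1:50003/roformer', data=json.dumps(paper), timeout=100)
if resp.text == 'error_roformer':
    label_nums = []
else:
    label_nums = json.loads(resp.text)['label_num']   # indexes of the predicted classes / libraries
print(label_nums)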

check_version_1_0/roformer/roformer_config.py (+26 lines)

@@ -0,0 +1,26 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/22 16:06
# @Author:ZYP
# @File:roformer_config.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import logging
import logging.handlers
import os
import multiprocessing
import gevent.monkey

gevent.monkey.patch_all()

bind = '0.0.0.0:50003'  # ip and port to bind
chdir = '/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/roformer'  # working directory gunicorn switches to
timeout = 60  # timeout in seconds
backlog = 2048
worker_class = 'gevent'  # use the gevent worker class; sync is also possible and is the default
workers = 1  # multiprocessing.cpu_count() * 2 + 1  # number of worker processes
loglevel = "info"  # log level; it applies to the error log, the access-log level cannot be set
access_log_format = '%(t)s %(p)s %(h)s "%(r)s" %(s)s %(L)s %(b)s %(f)s" "%(a)s"'  # gunicorn access-log format; the error-log format cannot be configured here
pidfile = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/roformer/roformer_log/gunicorn.pid"
accesslog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/roformer/roformer_log/access.log"
errorlog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/roformer/roformer_log/error.log"
daemon = True

check_version_1_0/util.py (+96 lines)

@@ -0,0 +1,96 @@
# -*- coding = utf-8 -*-
# @Time: 18:02
# @Author:ZYP
# @File:util.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
# =========================================================================================
# Utilities
# Load the stop words, the database and the word2vec / fasttext model services
# =========================================================================================
import os
import time
import math
import json
import jieba
import numpy as np
import requests
from collections import defaultdict
from textrank4zh import TextRank4Keyword

jieba.initialize()
stop_word_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/fasttext_train/data/total_stopwords.txt'


class Logging:
    def __init__(self):
        pass

    def log(self, *args, **kwargs):
        format = '%Y/%m/%d-%H:%M:%S'
        format_h = '%Y-%m-%d'
        value = time.localtime(int(time.time()))
        dt = time.strftime(format, value)
        dt_log_file = time.strftime(format_h, value)
        log_file = 'gunicornLogs/access-%s-%s' % (str(os.getpid()), dt_log_file) + ".log"
        if not os.path.exists(log_file):
            with open(os.path.join(log_file), 'w', encoding='utf-8') as f:
                print(dt, *args, file=f, **kwargs)
        else:
            with open(os.path.join(log_file), 'a+', encoding='utf-8') as f:
                print(dt, *args, file=f, **kwargs)


def load_stopwords(path=stop_word_path):
    """Load the stop words."""
    with open(path, 'r', encoding='utf-8') as f:
        stop_words = {i.strip() for i in f.readlines()}
    return stop_words


def cut_text(text_str, tokenizer='jieba'):
    """Tokenize the text with the chosen tokenizer, count the frequency of every word and return the dict sorted in descending order."""
    word_dict = defaultdict(int)
    if tokenizer == 'jieba':
        all_word_list = jieba.cut(text_str)
        for word in all_word_list:
            if word not in stop_word:
                word_dict[word] += 1
    # elif tokenizer == 'hanlp':
    #     for i in HanLP.segment(text_str):
    #         if i.word not in stop_word and i.nature != 'w':
    #             word_dict[i.word] += 1
    else:
        print('Invalid tokenizer parameter!')
    return {k: v for k, v in sorted(word_dict.items(), key=lambda x: x[1], reverse=True)}


def l2_normal(tf_idf_dict):
    """L2-normalize the computed tf-idf dict (into the 0-1 range) and keep the top-15 entries."""
    l2_norm = math.sqrt(sum(map(lambda x: x ** 2, tf_idf_dict.values())))
    tf_idf_dict1 = sorted(tf_idf_dict.items(), key=lambda x: x[1] / l2_norm, reverse=True)
    tf_idf_dict2 = {key: value / l2_norm for key, value in tf_idf_dict1[:15]}
    return tf_idf_dict2


def get_word_vec(word):
    """Fetch the vector of a word from the model services; return 0 if the word is in neither vocabulary, otherwise return the vector."""
    vec = requests.post('http://192.168.31.74:50001/word2vec', data=json.dumps({'word': word}), timeout=100)
    if len(vec.text) < 100:
        # fall back to the fastText service when word2vec does not know the word
        vec = requests.post('http://192.168.31.74:50002/fasttext', data=json.dumps({'word': word}), timeout=100)
        if len(vec.text) < 100:
            vec = 0
            return vec
        else:
            json_dict = json.loads(vec.text)
            res_vec = np.array([float(j) for j in json_dict["vec"].split(',')], dtype=np.float64)
            return res_vec
    else:
        json_dict = json.loads(vec.text)
        res_vec = np.array([float(j) for j in json_dict["vec"].split(',')], dtype=np.float64)
        return res_vec


stop_word = load_stopwords()
tr4w = TextRank4Keyword(stop_words_file=stop_word_path)
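To illustrate how cut_text and l2_normal fit together (the same keyword-weighting step search_sim_paper performs), a small sketch; the sample text is a placeholder, the idf uses an assumed corpus size of 1000 with every word treated as appearing in one document, and importing util requires the stopword file above to exist.

# Hypothetical keyword-weighting sketch built on util.cut_text / util.l2_normal.
import math
from util import cut_text, l2_normal

text = '论文查重依赖关键词统计。论文查重也依赖词向量相似度。'   # placeholder Chinese text
word_fre = cut_text(text, tokenizer='jieba')        # {word: frequency}, sorted descending
total = sum(word_fre.values())
# assumed idf: corpus of 1000 papers, document frequency 1 for every word
tf_idf = {w: (f / total) * math.log(1000 / 1) for w, f in word_fre.items()}
top_keywords = l2_normal(tf_idf)                    # L2-normalised weights of the top-15 words
print(top_keywords)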

check_version_1_0/word2vec/word2vec_api.py (+34 lines)

@@ -0,0 +1,34 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/22 14:44
# @Author:ZYP
# @File:word2vec_api.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import json
import numpy as np
from gensim.models.keyedvectors import KeyedVectors
import time
from flask import Flask, request

app_word2vec = Flask(__name__)
word2vec_path = "/home/zc-nlp-zyp/work_file/ssd_data/public_data/word2vec_model/word2vec.vector"
model_word2vec = KeyedVectors.load_word2vec_format(word2vec_path, binary=True)


@app_word2vec.route('/word2vec', methods=['POST'])
def get_word2vec():
    word_dict = json.loads(request.data.decode())
    try:
        vec = model_word2vec.get_vector(word_dict["word"])
        str_vec = ','.join([str(i) for i in vec])
        # vec1 = np.array([float(j) for j in str_vec.split(',')], dtype=np.float64)
        vec_dict = {'vec': str_vec}
        return json.dumps(vec_dict)
    except:
        # the word is not in the word2vec vocabulary
        return 'error_word2vec'

# if __name__ == '__main__':
#     app.run(host='0.0.0.0', port=50001, debug=False)

check_version_1_0/word2vec/word2vec_config.py (+26 lines)

@@ -0,0 +1,26 @@
# -*- coding:utf-8 -*-
# @Time: 2023/8/22 15:30
# @Author:ZYP
# @File:word2vec_config.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
import logging
import logging.handlers
import os
import multiprocessing
import gevent.monkey

gevent.monkey.patch_all()

bind = '0.0.0.0:50001'  # ip and port to bind
chdir = '/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/word2vec'  # working directory gunicorn switches to
timeout = 60  # timeout in seconds
worker_class = 'gevent'  # use the gevent worker class; sync is also possible and is the default
workers = 4  # multiprocessing.cpu_count() * 2 + 1  # number of worker processes
threads = 4
loglevel = "info"  # log level; it applies to the error log, the access-log level cannot be set
access_log_format = '%(t)s %(p)s %(h)s "%(r)s" %(s)s %(L)s %(b)s %(f)s" "%(a)s"'  # gunicorn access-log format; the error-log format cannot be configured here
pidfile = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/word2vec/word2vec_log/gunicorn.pid"
accesslog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/word2vec/word2vec_log/access.log"
errorlog = "/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/check1/change_demo/word2vec/word2vec_log/error.log"
daemon = True