# -*- coding: utf-8 -*-
# @Time: 18:02
# @Author:ZYP
# @File:util.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm

# =========================================================================================
# Utility module
# Loads the stop-word list, the MySQL database connection, and the word2vec / fastText models.
# =========================================================================================

import os
import math
import jieba
import pymysql
from pyhanlp import HanLP
from collections import defaultdict
from textrank4zh import TextRank4Keyword
from gensim.models.keyedvectors import KeyedVectors

stop_word_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/fasttext_train/data/total_stopwords.txt'

jieba.load_userdict('/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/fasttext_train/data'
                    '/user_dict_final_230316.txt')

os.environ['JAVA_HOME'] = '/home/zc-nlp-zyp/work_file/software/jdk1.8.0_341'

def deal_paper(target_paper_path):
    """Apply the cleaning strategy appropriate to the paper's format, turning the raw text into a dict split into title, abstract, body, etc., and return that dict."""
    paper_dict = {}
    """
    The concrete cleaning strategy depends on the input paper's format; handle each case individually.
    """
    return paper_dict

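# Illustrative sketch only: the exact keys depend on the cleaning strategy implemented in
# deal_paper. A cleaned paper is assumed to look roughly like
#     {'title': '...', 'abstract': '...', 'body': '...'}
# so downstream code could do, for example (hypothetical path):
#     paper = deal_paper('/path/to/paper.txt')
#     body_freq = cut_text(paper.get('body', ''))
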
class MysqlConnect:
    """MySQL connection helper: creates the connection object and a cursor."""

    def __init__(self, host='localhost', user='root', passwd='123456', database='zhiwang_class_db', charset='utf8'):
        self.conn = pymysql.connect(host=host, user=user, passwd=passwd, database=database, charset=charset)
        self.cursor = self.conn.cursor()

    def implement_sql(self, sql, is_close=True):
        """Execute a write statement (e.g. an INSERT) and commit; close the connection afterwards if is_close is True."""
        self.cursor.execute(sql)
        self.conn.commit()
        if is_close:
            self.cursor.close()
            self.conn.close()

    def select_sql(self, sql, is_close=True):
        """Execute a query and return all fetched rows; close the connection afterwards if is_close is True."""
        self.cursor.execute(sql)
        res = [i for i in self.cursor.fetchall()]
        if is_close:
            self.cursor.close()
            self.conn.close()
        return res

    def close_connect(self):
        self.cursor.close()
        self.conn.close()

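# Minimal usage sketch for MysqlConnect (assumes a reachable MySQL instance with the
# default credentials above; the query itself is illustrative):
#     conn = MysqlConnect(database='zhiwang_class_db')
#     rows = conn.select_sql("select doc_id, title from main_table_paper_detail_message limit 10",
#                            is_close=False)
#     conn.close_connect()
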
def load_stopwords(path=stop_word_path):
    """Load the stop-word list as a set."""
    with open(path, 'r', encoding='utf-8') as f:
        stop_words = {i.strip() for i in f.readlines()}
    return stop_words

def cut_text(text_str, tokenizer='jieba'):
    """Segment the text with the chosen tokenizer, count each word's frequency, and return the counts as a dict sorted in descending order."""
    word_dict = defaultdict(int)
    if tokenizer == 'jieba':
        all_word_list = jieba.cut(text_str)
        for word in all_word_list:
            if word not in stop_word:
                word_dict[word] += 1
    elif tokenizer == 'hanlp':
        for i in HanLP.segment(text_str):
            if i.word not in stop_word and i.nature != 'w':
                word_dict[i.word] += 1
    else:
        print('Invalid tokenizer argument, expected "jieba" or "hanlp"!')

    return {k: v for k, v in sorted(word_dict.items(), key=lambda x: x[1], reverse=True)}

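# Usage sketch for cut_text (illustrative; the exact tokens depend on the jieba user
# dictionary and the stop-word list loaded in this module):
#     freq = cut_text('本文提出了一种基于词向量的论文查重方法', tokenizer='jieba')
#     # -> roughly {'论文': 1, '查重': 1, ...}, sorted by count, descending
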
def l2_normal(tf_idf_dict):
    """L2-normalize the computed tf-idf dict so the values fall into (0, 1), keeping the 15 highest-weighted terms."""
    l2_norm = math.sqrt(sum(map(lambda x: x ** 2, tf_idf_dict.values())))
    tf_idf_dict1 = sorted(tf_idf_dict.items(), key=lambda x: x[1] / l2_norm, reverse=True)
    tf_idf_dict2 = {key: value / l2_norm for key, value in tf_idf_dict1[:15]}
    return tf_idf_dict2

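# Worked example for l2_normal (illustrative): for tf_idf_dict = {'a': 3.0, 'b': 4.0}
# the L2 norm is sqrt(3.0 ** 2 + 4.0 ** 2) = 5.0, so the function returns
# {'b': 0.8, 'a': 0.6}, i.e. the highest-weighted terms sorted by normalized weight.
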
def save_result(output_dir, result_dict):
    """
    Persist the plagiarism-check result dict to a local file.
    :param output_dir: directory the result file is written to
    :param result_dict: result dict
    :return:
    """
    output_path = os.path.join(output_dir, 'check_res.txt')
    f1 = open(output_path, 'a', encoding='utf-8')
    for doc_id, sent_dict in result_dict.items():
        select_sql = """
            select title from main_table_paper_detail_message where doc_id='{}'
        """.format(str(doc_id))

        mysql.cursor.execute(select_sql)
        title_name = mysql.cursor.fetchone()[0]
        for in_check_sent, out_check_sent_list in sent_dict.items():
            f1.write(
                in_check_sent + '||||' + "《" + title_name + "》" + '||||' + "[SEP]".join(out_check_sent_list) + '\n')

        f1.write('=' * 100 + '\n')

    f1.close()

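# Shape of result_dict expected by save_result, inferred from the loop above (illustrative):
#     {doc_id: {sentence_under_check: [similar_sentence_1, similar_sentence_2, ...]}}
# Each checked sentence is written out together with the source paper title fetched from MySQL.
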
def get_word_vec(word):
    """Look up a word's embedding: return the word2vec vector if the word is in that vocabulary, otherwise fall back to fastText; return 0 if the word is in neither."""
    if word in model_word2vec.key_to_index.keys():
        vec = model_word2vec.get_vector(word)
    else:
        try:
            vec = model_fasttext.get_vector(word)
        except KeyError:
            return 0
    return vec

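# Usage sketch for get_word_vec (assumes the word2vec / fastText models below are loaded
# from real paths; a return value of 0 means the word is in neither vocabulary):
#     vec = get_word_vec('查重')
#     if isinstance(vec, int):
#         pass  # out-of-vocabulary word: skip it or substitute a zero vector
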
# Load the word2vec / fastText models and the shared module-level resources
word2vec_path = ''
model_word2vec = KeyedVectors.load_word2vec_format(word2vec_path)
fasttext_path = ''
model_fasttext = KeyedVectors.load_word2vec_format(fasttext_path)
stop_word = load_stopwords()
mysql = MysqlConnect(database='zhiwang_class_db')
tr4w = TextRank4Keyword(stop_words_file=stop_word_path)
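

if __name__ == '__main__':
    # Minimal smoke-test sketch; assumes the two model paths above are filled in and the
    # stop-word file / MySQL instance are reachable on this machine.
    demo_freq = cut_text('本文提出了一种基于词向量的论文查重方法', tokenizer='jieba')
    print(demo_freq)
    print(l2_normal({'查重': 3.0, '词向量': 4.0}))  # expected: {'词向量': 0.8, '查重': 0.6}
    print(get_word_vec('论文'))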