
Overall plagiarism-check workflow

master
zhangyupeng 2 years ago
parent commit ec8162a24b
  1. 47   CheckPaper.py
  2. 151  LoadRoformer.py
  3. 194  SearchSimPaper.py
  4. 155  SearchSimSentence.py
  5. 146  util.py

47
CheckPaper.py

@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# @Time: 9:59
# @Author: ZYP
# @File: CheckPaper.py
# @mail: zypsunshine1@gmail.com
# @Software: PyCharm
# =========================================================================================
# Plagiarism-check entry point
# · Runs the document-level similarity search and the sentence-level comparison
# =========================================================================================
from SearchSimSentence import check_repeat_by_model, check_repeat_by_word2vec
from LoadRoformer import pred_class_num
from SearchSimPaper import search_sim_paper
from util import deal_paper, save_result


def main():
    # Path of the paper to be checked
    target_paper_path = ''
    # Output directory for the result file
    output_path = ''
    # Similarity threshold for flagging a sentence as duplicated
    threshold = 0.85
    # Normalize the raw paper into the expected format: {title: ..., abst_zh: ..., content: ...}
    paper_dict = deal_paper(target_paper_path)
    # Run the classification model to decide which category libraries to search in
    class_list = pred_class_num(paper_dict)
    # Retrieve candidate documents, returned as a dict {doc_id: (similarity score, document path)}
    sim_paper_id_dict = search_sim_paper(paper_dict, class_list)
    # Compare the submitted paper against each candidate document sentence by sentence
    # Result format: {doc_id: {sent1: [sim_sent, ...], sent2: [sim_sent, ...]}}
    result = check_repeat_by_word2vec(paper_dict, sim_paper_id_dict, threshold)  # word-weighting method (patent approach), averaged word vectors
    # result = check_repeat_by_model(paper_dict, sim_paper_id_dict, threshold)   # BERT-based sentence-to-sentence comparison
    # Save the result
    save_result(result, output_path)


if __name__ == '__main__':
    main()
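
The workflow above assumes deal_paper() returns a flat dict; here is a minimal sketch of that shape, with the field names taken from how pred_class_num and search_sim_paper read it (the placeholder values are not from the repository):

# Shape of the dict that deal_paper() is expected to return (placeholder values only)
example_paper_dict = {
    'title': '<paper title>',         # words found here get weight 0.5 in the document/sentence vectors
    'key_words': '<key words>',       # fed to the classifier and to the document retrieval step
    'abst_zh': '<Chinese abstract>',  # words found here get weight 0.3
    'content': '<full body text>',    # everything else gets weight 0.2
}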

151
LoadRoformer.py

@@ -0,0 +1,151 @@
# -*- coding: utf-8 -*-
# @Time: 16:41
# @Author: ZYP
# @File: LoadRoformer.py
# @mail: zypsunshine1@gmail.com
# @Software: PyCharm
# =========================================================================================
# Deep-learning model loading
# · Loads the paper classification model (RoFormer-v2)
# · Loads the BERT model used for sentence vectors
# =========================================================================================
import json
import numpy as np
from bert4keras.models import build_transformer_model
from keras.layers import Lambda, Dense
from keras.models import Model
from bert4keras.tokenizers import Tokenizer
def load_roformer_model(config, ckpt, model_weight_path):
    """Load the trained 168-label multi-label classification model."""
    roformer = build_transformer_model(
        config_path=config,
        checkpoint_path=ckpt,
        model='roformer_v2',
        return_keras_model=False)
    output = Lambda(lambda x: x[:, 0])(roformer.model.output)
    output = Dense(
        units=class_nums,
        kernel_initializer=roformer.initializer
    )(output)
    model1 = Model(roformer.model.input, output)
    model1.load_weights(model_weight_path)
    model1.summary()
    return model1
def load_label(label_path1):
    """Load label2id, id2label and the per-class thresholds used for classification."""
    with open(label_path1, 'r', encoding='utf-8') as f:
        json_dict = json.load(f)
    label2id1 = {i: j[0] for i, j in json_dict.items()}
    id2label1 = {j[0]: i for i, j in json_dict.items()}
    label_threshold1 = np.array([j[1] for i, j in json_dict.items()])
    return label2id1, id2label1, label_threshold1
def encode(text_list1):
    """Encode a list of text segments into a single token-id / segment-id pair."""
    sent_token_id1, sent_segment_id1 = [], []
    for index, text in enumerate(text_list1):
        if index == 0:
            token_id, segment_id = tokenizer_roformer.encode(text)
        else:
            # Drop the leading [CLS] token of every segment after the first
            token_id, segment_id = tokenizer_roformer.encode(text)
            token_id = token_id[1:]
            segment_id = segment_id[1:]
        if (index + 1) % 2 == 0:
            segment_id = [1] * len(token_id)
        sent_token_id1 += token_id
        sent_segment_id1 += segment_id
    if len(sent_token_id1) > max_len:
        sent_token_id1 = sent_token_id1[:max_len]
        sent_segment_id1 = sent_segment_id1[:max_len]
    sent_token_id = np.array([sent_token_id1])
    sent_segment_id = np.array([sent_segment_id1])
    return sent_token_id, sent_segment_id
def load_bert_model(config, ckpt, model_weight_path):
    """Load the BERT model used to extract sentence vectors."""
    bert = build_transformer_model(
        config_path=config,
        checkpoint_path=ckpt,
        model='bert',
        return_keras_model=False)
    output = Lambda(lambda x: x[:, 0])(bert.model.output)
    model1 = Model(bert.model.input, output)
    model1.load_weights(model_weight_path)
    model1.summary()
    return model1


def return_sent_vec(sent_list):
    """
    Convert a list of sentences into sentence vectors with the BERT model.
    :param sent_list: list of sentences
    :return: the sentence list and the corresponding list of sentence vectors
    """
    sent_vec_list = []
    for sent in sent_list:
        token_ids, segment_ids = tokenizer_bert.encode(sent, maxlen=512)
        sent_vec = bert_model.predict([np.array([token_ids]), np.array([segment_ids])])
        sent_vec_list.append(sent_vec[0].tolist())
    return sent_list, sent_vec_list
def pred_class_num(target_paper_dict):
    """
    Predict the category indices of the submitted paper and return them as a list.
    The paper must already be a dict containing title, key_words, abst_zh, content, etc.
    """
    text_list1 = [target_paper_dict['title'], target_paper_dict['key_words']]
    abst_zh = target_paper_dict['abst_zh']
    # Split the abstract into sentences on the Chinese full stop; if it is short use it as-is,
    # otherwise keep only the first five and the last five sentences
    if len(abst_zh.split('。')) <= 10:
        text_list1.append(abst_zh)
    else:
        text_list1.append('。'.join(abst_zh.split('。')[:5]))
        text_list1.append('。'.join(abst_zh.split('。')[-5:]))
    sent_token, segment_ids = encode(text_list1)
    y_pred = model_roformer.predict([sent_token, segment_ids])
    # A class is predicted when its score exceeds its per-class threshold
    idx = np.where(y_pred[0] > label_threshold, 1, 0)
    pred_label_num = [index for index, i in enumerate(idx) if i == 1]
    return pred_label_num
# =========================================================================================================================
# RoFormer model parameters
# =========================================================================================================================
class_nums = 168
max_len = 1500
roformer_config_path = ''
roformer_ckpt_path = ''
roformer_vocab_path = ''
roformer_model_weights_path = ''
label_path = '../data/label_threshold.txt'
# Tokenizer for the RoFormer model
tokenizer_roformer = Tokenizer(roformer_vocab_path)
# Load the label information (label2id, id2label, per-class thresholds)
label2id, id2label, label_threshold = load_label(label_path)
# Load the trained classification model
model_roformer = load_roformer_model(roformer_config_path, roformer_ckpt_path, roformer_model_weights_path)
# =========================================================================================================================
# BERT model parameters
# =========================================================================================================================
bert_config_path = ''
bert_ckpt_path = ''
bert_vocab_path = ''
bert_model_weight_path = ''
# Tokenizer for the BERT model
tokenizer_bert = Tokenizer(bert_vocab_path)
# Load the BERT model used to extract sentence vectors
bert_model = load_bert_model(bert_config_path, bert_ckpt_path, bert_model_weight_path)
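
A small self-contained sketch of the per-class thresholding that pred_class_num applies to the classifier output; the scores and thresholds below are invented for illustration:

import numpy as np

# Hypothetical scores for 5 classes and their per-class thresholds
y_pred = np.array([0.91, 0.10, 0.55, 0.80, 0.02])
label_threshold = np.array([0.50, 0.40, 0.60, 0.70, 0.30])

idx = np.where(y_pred > label_threshold, 1, 0)              # -> [1, 0, 0, 1, 0]
pred_label_num = [i for i, flag in enumerate(idx) if flag == 1]
print(pred_label_num)                                        # -> [0, 3]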

194
SearchSimPaper.py

@@ -0,0 +1,194 @@
# -*- coding: utf-8 -*-
# @Time: 18:01
# @Author: ZYP
# @File: SearchSimPaper.py
# @mail: zypsunshine1@gmail.com
# @Software: PyCharm
# =========================================================================================
# Similar-document retrieval
# · Intersect the submitted paper's keywords with the inverted index of each category
# · Compute the similarity between the selected documents and the submitted paper
# · Sort the most similar documents and return them
# =========================================================================================
import math
import numpy as np
from collections import defaultdict
from pymysql.converters import escape_string
from sklearn.metrics.pairwise import cosine_similarity
from util import cut_text, l2_normal, mysql, get_word_vec
def load_inverted_table(class_list):
    """
    Merge the inverted indexes of the predicted categories and return the merged index
    together with the total number of papers in those categories.
    """
    # Merged inverted index {word: [doc_id1, doc_id2, doc_id3, ...]}
    total_inverted_dict1 = {}
    # Total number of papers across the selected categories
    total_nums1 = 0
    for label_num in class_list:
        select_sql = """
        select word, paper_doc_id from word_map_paper_{};
        """.format(str(label_num))
        mysql.cursor.execute(select_sql)
        for word, paper_doc_id in mysql.cursor.fetchall():
            if word not in total_inverted_dict1.keys():
                total_inverted_dict1[word] = paper_doc_id.split(',')
            else:
                total_inverted_dict1[word] = sorted(list(set(total_inverted_dict1[word] + paper_doc_id.split(','))),
                                                    reverse=False)
        select_paper_num_sql = """
        select count(*) from main_table_paper_detail_message where label_num={};
        """.format(label_num)
        mysql.cursor.execute(select_paper_num_sql)
        for nums in mysql.cursor.fetchall():
            total_nums1 += nums[0]
    return total_inverted_dict1, total_nums1
def select_sim_doc_message(sim_doc1):
    """
    Look up each similar doc_id in the database and compute its weighted mean document vector.
    :param sim_doc1: list of similar document ids [doc_id1, doc_id2, ...]
    :return: dict {doc_id: (doc_avg_vec, doc_path)}
    """
    all_paper_vec_dict = {}
    for doc_id in sim_doc1:
        select_sql = """
        select tb1.doc_id, tb1.title, tb1.abst_zh, tb2.vsm, tb1.content from
        (
            (select doc_id, title, abst_zh, content from main_table_paper_detail_message) tb1
            left join
            (select doc_id, vsm from id_keywords_weights1) tb2
            on
            tb1.doc_id=tb2.doc_id
        ) where tb1.doc_id="{}";
        """.format(
            escape_string(doc_id))
        mysql.cursor.execute(select_sql)
        sim_doc_id, sim_title, sim_abst, sim_vsm, sim_content_path = mysql.cursor.fetchone()
        sim_vsm_dict = {weight.split('=')[0]: float(weight.split('=')[1]) for weight in sim_vsm.split(',')}
        vector_paper = []
        value_sum = 0.0
        for word, weight in sim_vsm_dict.items():
            # Words appearing in the title / abstract / body are weighted 0.5 / 0.3 / 0.2
            if word in sim_title:
                value = 0.5 * weight
            elif word in sim_abst:
                value = 0.3 * weight
            else:
                value = 0.2 * weight
            word_vec = get_word_vec(word)
            if word_vec is None:
                continue
            vector_paper.append(word_vec * value)
            value_sum += value
        # Weighted mean of the keyword vectors of one document
        # avg_vector = np.array(np.sum(np.array(vector_paper, dtype=np.float32), axis=0) / len(vector_paper))
        avg_vector = np.array(np.sum(np.array(vector_paper, dtype=np.float32), axis=0) / value_sum)
        all_paper_vec_dict[doc_id] = (avg_vector, sim_content_path)
    return all_paper_vec_dict
def submit_paper_avg_vec(paper_dict1, tf_weight_dict):
    """Compute the weighted mean vector of the submitted paper from its tf weights; returns a numpy array."""
    vector_paper = []
    value_sum = 0.0
    for word, weight in tf_weight_dict.items():
        if word in paper_dict1['title']:
            value = 0.5 * weight
        elif word in paper_dict1['abst_zh']:
            value = 0.3 * weight
        else:
            value = 0.2 * weight
        word_vec = get_word_vec(word)
        if word_vec is None:
            continue
        vector_paper.append(word_vec * value)
        value_sum += value
    # avg_vector = np.array(np.sum(np.array(vector_paper, dtype=np.float32), axis=0) / len(vector_paper))
    avg_vector = np.array(np.sum(np.array(vector_paper, dtype=np.float32), axis=0) / value_sum)
    return avg_vector
def compare_sim_in_papers(check_vector, sim_message, top=40):
    """
    Compute the cosine similarity between the submitted paper and each candidate document.
    :param check_vector: document vector of the submitted paper
    :param sim_message: candidate similar documents, stored as {doc_id: (doc_avg_vec, doc_path)}
    :param top: number of most similar documents to return
    :return: dict of similar documents {doc_id: (similarity score, document path)}
    """
    sim_res_dict = {}
    for doc_id, (vector, content_path) in sim_message.items():
        # sim_res_dict[doc_id] = cosine_similarity([scalar(check_vector), scalar(vector)])[0][1]
        sim_res_dict[doc_id] = (cosine_similarity([check_vector, vector])[0][1], content_path)
    # Sort by the similarity score (the first element of the value tuple)
    _ = sorted(sim_res_dict.items(), key=lambda x: x[1][0], reverse=True)
    return {key: value for key, value in _[:top]}
def search_sim_paper(paper_dict, class_list, top=10):
    """
    Search the database for documents similar to the submitted paper and return the top
    candidates for sentence-level checking.
    :param paper_dict: formatted submitted paper
    :param class_list: list of category ids predicted for the submitted paper
    :param top: number of documents to return
    :return: dict of similar documents {doc_id: (similarity score, document path)}
    """
    all_str = paper_dict['title'] + '。' + paper_dict['key_words'] + '。' + paper_dict['content']
    # Merge the inverted indexes of the predicted categories and count the total number of papers
    total_inverted_dict, total_nums = load_inverted_table(class_list)
    # Word-frequency dict of the submitted paper {word1: freq1, word2: freq2, ...}
    word_fre_dict = cut_text(all_str, tokenizer='jieba')
    # Compute the tf-idf value of every word of the submitted paper
    tf_idf_dict = {}
    total_freq = sum(word_fre_dict.values())
    for word, freq in word_fre_dict.items():
        tf = freq / total_freq
        if word in total_inverted_dict.keys():
            idf = math.log(total_nums / (len(total_inverted_dict[word]) + 1))
        else:
            idf = math.log(total_nums / 1)
        tf_idf = tf * idf
        tf_idf_dict[word] = tf_idf
    # Top-15 words and their normalized weights
    tf_dict = l2_normal(tf_idf_dict)
    # Count, for every candidate document, how many of these words it shares with the submitted paper
    count_words_num = defaultdict(int)
    for word, weight in tf_dict.items():
        if word in total_inverted_dict.keys():
            for doc_id in total_inverted_dict[word]:
                count_words_num[doc_id] += 1
        else:
            continue
    # Sort the candidates by the number of shared words
    count_word_num = {i: j for i, j in sorted(count_words_num.items(), key=lambda x: x[1], reverse=True)}
    # Keep the top-200 candidate documents
    sim_doc = list(count_word_num.keys())[:200]
    # Compute the mean document vector of each candidate
    sim_paper_vec_dict = select_sim_doc_message(sim_doc)
    # Compute the mean document vector of the submitted paper
    submit_vec = submit_paper_avg_vec(paper_dict, tf_dict)
    # Rank the candidates by cosine similarity to the submitted paper and keep the top documents
    sim_paper_dict = compare_sim_in_papers(submit_vec, sim_paper_vec_dict, top=top)
    return sim_paper_dict
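
A toy walk-through of the candidate-scoring idea in search_sim_paper (tf-idf against a merged inverted index, then L2-normalized top keywords); the inverted index, corpus size and word frequencies are invented, and the cut is top-3 instead of the top-15 used by util.l2_normal:

import math

# Invented merged inverted index {word: [doc ids containing the word]} and corpus size
total_inverted_dict = {'查重': ['d1', 'd2'], '向量': ['d2'], '模型': ['d1', 'd2', 'd3']}
total_nums = 1000
word_fre_dict = {'查重': 5, '向量': 3, '模型': 2, '本文': 1}

total_freq = sum(word_fre_dict.values())
tf_idf_dict = {}
for word, freq in word_fre_dict.items():
    tf = freq / total_freq
    if word in total_inverted_dict:
        idf = math.log(total_nums / (len(total_inverted_dict[word]) + 1))
    else:
        idf = math.log(total_nums)
    tf_idf_dict[word] = tf * idf

# L2-normalize and keep the top-3 keywords, mirroring util.l2_normal
l2_norm = math.sqrt(sum(v ** 2 for v in tf_idf_dict.values()))
top_keywords = dict(sorted(((w, v / l2_norm) for w, v in tf_idf_dict.items()),
                           key=lambda x: x[1], reverse=True)[:3])
print(top_keywords)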

155
SearchSimSentence.py

@@ -0,0 +1,155 @@
# -*- coding: utf-8 -*-
# @Time: 2023/3/16 18:27
# @Author: ZYP
# @File: SearchSimSentence.py
# @mail: zypsunshine1@gmail.com
# @Software: PyCharm
# =========================================================================================
# Sentence-level duplicate checking
# · Sentence-to-sentence comparison with word2vec / fasttext word vectors
# · Sentence-to-sentence comparison with the deep-learning (BERT) model
# =========================================================================================
import os
import re
import json
import numpy as np
from pyhanlp import HanLP
from collections import defaultdict
from CheckPaper.LoadRoformer import return_sent_vec
from CheckPaper.util import stop_word, get_word_vec
from sklearn.metrics.pairwise import cosine_similarity
os.environ['JAVA_HOME'] = '/home/zc-nlp-zyp/work_file/software/jdk1.8.0_341'
def text2vec(paper_dict):
    """
    Split the text into sentences, segment each sentence into words, look the words up in the
    word2vec/fasttext model and combine them into a weighted sentence vector.
    :param paper_dict: dict with the paper's fields
    :return: dict {sentence: sentence vector}
    """
    in_check_str = paper_dict['title'] + '。' + paper_dict['abst_zh'] + '。' + paper_dict['content']
    in_check_sent_list = re.split(r'。|,|:|;|!|?', in_check_str)
    sent_dict = {}
    for sent in in_check_sent_list:
        word_list = HanLP.segment(sent)
        sent_vec = []
        value_sum = 0.0
        for i in [word.word for word in word_list if word.word not in stop_word and str(word.nature) != 'w']:
            # Words appearing in the title / abstract / body are weighted 0.5 / 0.3 / 0.2
            if i in paper_dict['title']:
                weight = 0.5
            elif i in paper_dict['abst_zh']:
                weight = 0.3
            else:
                weight = 0.2
            word_vec = get_word_vec(i)
            if word_vec is None:
                continue
            vec = (weight * word_vec).tolist()
            value_sum += weight
            sent_vec.append(vec)
        # sent_vec = np.sum(np.array(sent_vec), axis=0) / len(sent_vec)  # [1, 300]
        sent_vec = np.sum(np.array(sent_vec), axis=0) / value_sum  # [1, 300]
        sent_dict[sent] = sent_vec.tolist()
    return sent_dict
def deal_in_paper(paper_dict):
    """Split the paper into sentences and return the sentence list."""
    in_check_str = paper_dict['title'] + '。' + paper_dict['abst_zh'] + '。' + paper_dict['content']
    in_check_sent_list = re.split(r'。|,|:|;|!|?', in_check_str)
    return in_check_sent_list
def check_repeat_by_model(paper_dict, sim_paper_id_dict, threshold):
    """
    Compare the submitted paper against the similar papers sentence by sentence using the BERT model.
    :param threshold: similarity threshold for flagging a repeated sentence
    :param paper_dict: formatted submitted paper
    :param sim_paper_id_dict: similar papers of the same categories, format {doc_id: (similarity score, document path)}
    :return: for every similar paper, the sentences whose similarity exceeds the threshold,
             format {doc_id: {submitted sentence: [similar sentences in doc_id]}}
    """
    res_dict = defaultdict(dict)
    # {
    #   doc_id1: {
    #       submitted sentence 1: [similar sentence 1, similar sentence 2, ...],
    #       submitted sentence 2: [similar sentence 1, similar sentence 2, ...]
    #   },
    #   doc_id2: {
    #       submitted sentence 1: [similar sentence 1, similar sentence 2, ...],
    #       submitted sentence 2: [similar sentence 1, similar sentence 2, ...]
    #   },
    #   ...
    # }
    in_check_sent_list = deal_in_paper(paper_dict)
    # Convert the submitted sentences into vectors with the BERT model
    in_check_sent, in_check_vec = return_sent_vec(in_check_sent_list)
    for doc_id, (_, path) in sim_paper_id_dict.items():
        with open(path, 'r', encoding='utf-8') as f:
            json_dict = json.load(f)
        check_sent_list = deal_in_paper(json_dict)
        out_check_sent, out_check_vec = return_sent_vec(check_sent_list)
        sim_matrix = cosine_similarity(in_check_vec, out_check_vec)
        for index, row in enumerate(sim_matrix):
            sim_id = np.where(row >= threshold)[0].tolist()
            for j in sim_id:
                res_dict[doc_id].setdefault(in_check_sent[index], []).append(out_check_sent[j])
    return res_dict
def check_repeat_by_word2vec(paper_dict, sim_paper_id_dict, threshold):
    """
    Compare the submitted paper against the similar papers sentence by sentence using word vectors.
    :param threshold: similarity threshold for flagging a repeated sentence
    :param paper_dict: formatted submitted paper
    :param sim_paper_id_dict: similar papers of the same categories, format {doc_id: (similarity score, document path)}
    :return: for every similar paper, the sentences whose similarity exceeds the threshold,
             format {doc_id: {submitted sentence: [similar sentences in doc_id]}}
    """
    in_sent_dict = text2vec(paper_dict)  # {sent1: vec1, sent2: vec2, sent3: vec3, ...}
    check_dict = {}  # {doc_id1: {sent1: vec1, sent2: vec2, ...}, doc_id2: {sent1: vec1, sent2: vec2, ...}, ...}
    for doc_id, (_, path) in sim_paper_id_dict.items():
        with open(path, 'r', encoding='utf-8') as f:
            text_dict = json.load(f)
        sent_dict_ = text2vec(text_dict)
        check_dict[doc_id] = sent_dict_
    in_sent_list = [sent for sent, vec in in_sent_dict.items()]
    in_sent_vec_list = [vec for sent, vec in in_sent_dict.items()]  # [sent_num_in, 300]
    total_result = {}  # {doc_id: {sent1: [sim_sent1, sim_sent2, ...], sent2: [sim_sent1, sim_sent2, ...], ...}}
    for doc_id, sent_dict in check_dict.items():
        result = {sent: [] for sent, _ in in_sent_dict.items()}  # {submitted sentence: [similar sentences]}
        every_sent_list = [sent for sent, vec in sent_dict.items()]
        every_sent_vec_list = [vec for sent, vec in sent_dict.items()]  # [sent_num_every, 300]
        sim_score = cosine_similarity(np.array(in_sent_vec_list),
                                      np.array(every_sent_vec_list))  # [sent_num_in, sent_num_every]
        for check_index, sim_array in enumerate(sim_score):
            sim_id = np.where(sim_array >= threshold)[0]
            for i in sim_id:
                result[in_sent_list[check_index]].append(every_sent_list[i])
        total_result[doc_id] = result
    return total_result
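
A minimal sketch of the similarity-matrix thresholding used by both check functions; the sentence vectors are random placeholders, with one vector copied across so the sketch prints a hit:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

rng = np.random.default_rng(0)
in_vecs = rng.normal(size=(3, 300))     # 3 submitted sentences
out_vecs = rng.normal(size=(4, 300))    # 4 sentences from one candidate document
out_vecs[0] = in_vecs[1]                # plant one identical "sentence" so there is a guaranteed match
threshold = 0.85

sim_matrix = cosine_similarity(in_vecs, out_vecs)   # shape [3, 4]
for row_idx, row in enumerate(sim_matrix):
    hits = np.where(row >= threshold)[0]
    if len(hits):
        print(f"submitted sentence {row_idx} matches candidate sentences {hits.tolist()}")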

146
util.py

@@ -0,0 +1,146 @@
# -*- coding: utf-8 -*-
# @Time: 18:02
# @Author: ZYP
# @File: util.py
# @mail: zypsunshine1@gmail.com
# @Software: PyCharm
# =========================================================================================
# Utilities
# · Loads the stop words, the MySQL connection and the word2vec / fasttext models
# =========================================================================================
import os
import math
import jieba
import pymysql
from pyhanlp import HanLP
from collections import defaultdict
from textrank4zh import TextRank4Keyword
from gensim.models.keyedvectors import KeyedVectors
stop_word_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/fasttext_train/data/total_stopwords.txt'
jieba.load_userdict('/home/zc-nlp-zyp/work_file/ssd_data/program/check_paper/fasttext_train/data'
'/user_dict_final_230316.txt')
os.environ['JAVA_HOME'] = '/home/zc-nlp-zyp/work_file/software/jdk1.8.0_341'
def deal_paper(target_paper_path):
    """
    Clean the raw paper according to its format and turn it into a dict with
    title, abstract, body, etc.; returns the dict.
    """
    paper_dict = {}
    """
    The concrete cleaning strategy depends on the input format and is decided case by case.
    """
    return paper_dict
class MysqlConnect:
    """MySQL connection helper; creates the connection and cursor."""

    def __init__(self, host='localhost', user='root', passwd='123456', database='zhiwang_class_db', charset='utf8'):
        self.conn = pymysql.connect(host=host, user=user, passwd=passwd, database=database, charset=charset)
        self.cursor = self.conn.cursor()

    def implement_sql(self, sql, is_close=True):
        """Execute a write statement and commit; optionally close the connection afterwards."""
        self.cursor.execute(sql)
        self.conn.commit()
        if is_close:
            self.cursor.close()
            self.conn.close()

    def select_sql(self, sql, is_close=True):
        """Execute a query and return all rows; optionally close the connection afterwards."""
        self.cursor.execute(sql)
        res = [i for i in self.cursor.fetchall()]
        if is_close:
            self.cursor.close()
            self.conn.close()
        return res

    def close_connect(self):
        self.cursor.close()
        self.conn.close()
def load_stopwords(path=stop_word_path):
    """Load the stop-word set."""
    with open(path, 'r', encoding='utf-8') as f:
        stop_words = {i.strip() for i in f.readlines()}
    return stop_words
def cut_text(text_str, tokenizer='jieba'):
    """Segment the text with the chosen tokenizer and return a word-frequency dict sorted in descending order."""
    word_dict = defaultdict(int)
    if tokenizer == 'jieba':
        all_word_list = jieba.cut(text_str)
        for word in all_word_list:
            if word not in stop_word:
                word_dict[word] += 1
    elif tokenizer == 'hanlp':
        for i in HanLP.segment(text_str):
            if i.word not in stop_word and str(i.nature) != 'w':
                word_dict[i.word] += 1
    else:
        print('Invalid tokenizer argument!')
    return {k: v for k, v in sorted(word_dict.items(), key=lambda x: x[1], reverse=True)}
def l2_normal(tf_idf_dict):
    """L2-normalize the tf-idf dict and keep the 15 highest-weighted words."""
    l2_norm = math.sqrt(sum(map(lambda x: x ** 2, tf_idf_dict.values())))
    tf_idf_dict1 = sorted(tf_idf_dict.items(), key=lambda x: x[1] / l2_norm, reverse=True)
    tf_idf_dict2 = {key: value / l2_norm for key, value in tf_idf_dict1[:15]}
    return tf_idf_dict2
def save_result(result_dict, output_dir):
    """
    Write the duplicate-check result dict to disk.
    (Parameter order matches the call in CheckPaper.main: save_result(result, output_path).)
    :param result_dict: result dict {doc_id: {submitted sentence: [similar sentences]}}
    :param output_dir: output directory for the result file
    :return:
    """
    output_path = os.path.join(output_dir, 'check_res.txt')
    f1 = open(output_path, 'a', encoding='utf-8')
    for doc_id, sent_dict in result_dict.items():
        select_sql = """
        select title from main_table_paper_detail_message where doc_id='{}'
        """.format(str(doc_id))
        mysql.cursor.execute(select_sql)
        title_name = mysql.cursor.fetchone()[0]
        for in_check_sent, out_check_sent_list in sent_dict.items():
            f1.write(
                in_check_sent + '||||' + '《' + title_name + '》' + '||||' + "[SEP]".join(out_check_sent_list) + '\n')
        f1.write('=' * 100 + '\n')
    f1.close()
def get_word_vec(word):
    """
    Look up the vector of a word: try the word2vec model first, then fasttext;
    return None if the word is in neither vocabulary.
    """
    if word in model_word2vec.key_to_index.keys():
        vec = model_word2vec.get_vector(word)
    else:
        try:
            vec = model_fasttext.get_vector(word)
        except KeyError:
            return None
    return vec
# Load the word2vec model
word2vec_path = ''
model_word2vec = KeyedVectors.load_word2vec_format(word2vec_path)
# Load the fasttext vectors (exported in word2vec format)
fasttext_path = ''
model_fasttext = KeyedVectors.load_word2vec_format(fasttext_path)
stop_word = load_stopwords()
mysql = MysqlConnect(database='zhiwang_class_db')
tr4w = TextRank4Keyword(stop_words_file=stop_word_path)
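
A small sketch of the weighted word-vector averaging convention used across the project (title words 0.5, abstract words 0.3, everything else 0.2); the 4-dimensional vectors are made up so the snippet runs without the real word2vec/fasttext models:

import numpy as np

# Invented 4-dimensional word vectors standing in for model_word2vec / model_fasttext lookups
fake_vectors = {
    '查重': np.array([0.1, 0.2, 0.3, 0.4]),
    '方法': np.array([0.5, 0.1, 0.0, 0.2]),
    '实验': np.array([0.0, 0.3, 0.2, 0.1]),
}
title, abstract = '查重方法', '实验部分'

vecs, value_sum = [], 0.0
for word, vec in fake_vectors.items():
    weight = 0.5 if word in title else 0.3 if word in abstract else 0.2
    vecs.append(vec * weight)
    value_sum += weight
avg_vector = np.sum(np.array(vecs, dtype=np.float32), axis=0) / value_sum
print(avg_vector)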