# -*- coding: utf-8 -*-
"""
@Time    : 2023/3/9 18:36
@Author  :
@FileName:
@Software:
@Describe: Embed every line of a text file with ``ClassifyModel`` and, for
           each line, write the most cosine-similar entries of a pre-computed
           reference vector set to an Excel sheet.
"""
# NOTE(review): the three comment lines below are translated from the original
# header. They describe a CRF named-entity-recognition experiment, while the
# code in this file performs a cosine-similarity lookup -- they look like
# leftovers from a template; confirm and remove if so.
# Chinese named-entity recognition with CRF
# Dataset: http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
# Measured F1: up to 96.18% on the validation set, 95.35% on the test set

import os

# Set at import time, before TensorFlow is loaded in main(), exactly as the
# original script ordered it.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from operator import itemgetter

import numpy as np
import pandas as pd
from numpy.linalg import norm

BATCH_SIZE = 32
TOP_K = 10  # number of best matches recorded per input line

# BERT configuration
BERT_DIR = 'chinese_roberta_wwm_ext_L-12_H-768_A-12'
CONFIG_PATH = BERT_DIR + '/bert_config.json'
CHECKPOINT_PATH = BERT_DIR + '/bert_model.ckpt'
DICT_PATH = BERT_DIR + '/vocab.txt'

DATA_DIR = "data/10235513_大型商业建筑人员疏散设计研究_沈福禹"
LABEL_VEC_PATH = DATA_DIR + "/save_x.npy"
REFERENCE_CSV_PATH = DATA_DIR + "/查重.csv"
PATH_TXT = DATA_DIR + "/大型商业建筑人员疏散设计研究.txt"
PATH_EXCEL = DATA_DIR + "/大型商业建筑人员疏散设计研究_2.xlsx"


def cos_sim(a, b):
    """Return the cosine similarity of vectors *a* and *b*.

    Both arguments are converted with ``np.array``. Cosine similarity is
    undefined when either vector has zero norm; in that case ``0.0`` is
    returned instead of dividing by zero (which would yield NaN and make the
    ranking sort unreliable).
    """
    vec_a = np.array(a)
    vec_b = np.array(b)
    denominator = norm(vec_a) * norm(vec_b)
    if denominator == 0:
        return 0.0
    return np.dot(vec_a, vec_b) / denominator


def _top_matches(query_vec, reference_vecs, top_k=TOP_K):
    """Return at most *top_k* ``(index, score)`` pairs, highest score first."""
    scores = [cos_sim(query_vec, vec) for vec in reference_vecs]
    # sorted() is stable, so equal scores keep their original index order.
    ranked = sorted(enumerate(scores), key=itemgetter(1), reverse=True)
    # Slicing (rather than indexing 0..top_k-1) tolerates a reference set
    # that holds fewer than top_k vectors.
    return ranked[:top_k]


def _build_row(text, matches, reference_rows):
    """Return one output row: *text*, then score/column-0/file name per match."""
    row = [text]
    for index, score in matches:
        record = reference_rows[index]
        row.append(score)
        row.append(record[0])
        # Column 1 is split on backslashes and only the last component kept.
        row.append(record[1].split("\\")[-1])
    return row


def main():
    """Embed each line of ``PATH_TXT`` and write its best matches to Excel.

    Lines starting with ``*****`` are skipped. Every other line (blank lines
    included) is passed through the model, compared against all reference
    vectors and written as one row of ``PATH_EXCEL``.
    """
    # Heavy ML imports are deferred so that importing this module (e.g. to
    # reuse cos_sim) does not load TensorFlow or the model code. TensorFlow is
    # not referenced directly; it is imported first to keep the original
    # load order.
    import tensorflow as tf  # noqa: F401
    from src.basemodel import ClassifyModel

    reference_vecs = np.load(LABEL_VEC_PATH)
    reference_rows = pd.read_csv(
        REFERENCE_CSV_PATH, encoding="utf-8"
    ).values.tolist()

    model = ClassifyModel(
        CONFIG_PATH,
        CHECKPOINT_PATH,
        DICT_PATH,
        is_train=False,
        load_weights_path=None,
    )

    with open(PATH_TXT, encoding="utf-8") as handle:
        content = handle.read()

    rows = []
    for text in content.split("\n"):
        if text.startswith("*****"):
            continue

        data = model.data_generator([text], BATCH_SIZE)
        token, segment = data[0][0], data[1][0]
        # Flatten the model output to a 1-D vector before comparing.
        embedding = model.predict(token, segment).reshape(-1)

        matches = _top_matches(embedding, reference_vecs)
        rows.append(_build_row(text, matches, reference_rows))

    pd.DataFrame(rows).to_excel(PATH_EXCEL, index=False)


if __name__ == '__main__':
    main()