You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

100 lines
3.6 KiB

# -*- coding: utf-8 -*-
"""
@Time : 2023/3/9 18:36
@Author :
@FileName:
@Software:
@Describe:
"""
#! -*- coding: utf-8 -*-
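# Chinese named-entity recognition with a CRF
# Dataset: http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
# Measured F1: up to 96.18% on the validation set and 95.35% on the test set
# NOTE(review): the three lines above describe an NER experiment, but the code in this
# file ranks stored vectors by cosine similarity; confirm and update this header.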
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import tensorflow as tf
import os
from src.basemodel import ClassifyModel
import numpy as np
from numpy.linalg import norm
import pandas as pd
# a = [[1, 3, 2], [2, 2, 1]]
# print(cosine_similarity(a))
def cos_sim(a, b):
    """Return the cosine similarity between two vectors.

    Args:
        a: First vector (any array-like accepted by ``numpy.asarray``).
        b: Second vector, with a shape compatible with ``a`` for ``numpy.dot``.

    Returns:
        ``dot(a, b) / (norm(a) * norm(b))``. When either vector has zero
        norm the similarity is undefined; ``0.0`` is returned in that case
        instead of NaN so that callers can still sort the scores.
    """
    A = np.asarray(a)
    B = np.asarray(b)
    denominator = norm(A) * norm(B)
    # A zero denominator would produce 0/0 -> NaN plus a RuntimeWarning, and
    # NaN values make any later ordering of the scores unreliable.
    if denominator == 0:
        return 0.0
    return np.dot(A, B) / denominator
if __name__ == '__main__':
    # Script entry point: embed every line of a text file with ClassifyModel,
    # rank the pre-computed reference vectors by cosine similarity against each
    # line, and write the best matches per line to an Excel sheet.
    maxlen = 512  # not referenced anywhere below
    batch_size = 32
    top_k = 10  # number of best-matching reference rows recorded per input line

    # BERT configuration
    config_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_config.json'
    checkpoint_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
    dict_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'

    # Reference vectors; iterated one vector at a time in the ranking loop below.
    lable_vec_path = "data/10235513_大型商业建筑人员疏散设计研究_沈福禹/save_x.npy"
    b = np.load(lable_vec_path)

    # Reference rows, indexed with the same position as the vectors in `b`.
    # Column 0 is copied to the output as-is; column 1 is split on backslashes
    # and only its last component is kept (presumably a Windows file path --
    # TODO confirm).
    df_train_nuoche = pd.read_csv(
        "data/10235513_大型商业建筑人员疏散设计研究_沈福禹/查重.csv",
        encoding="utf-8",
    ).values.tolist()

    classifymodel = ClassifyModel(config_path, checkpoint_path, dict_path, is_train=False, load_weights_path=None)

    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # Disabled interactive variant: read one query from stdin and print its top matches.
    # while True:
    #     text = input("請輸入")
    #     data = classifymodel.data_generator([text], batch_size)
    #     token, segment = data[0][0], data[1][0]
    #     content_cls = classifymodel.predict(token, segment)
    #     content_cls = content_cls.reshape(-1)
    #     print(content_cls.shape)
    #
    #     index_list = []
    #     for vec in b:
    #
    #         cos_value = cos_sim(content_cls, vec)
    #         index_list.append(cos_value)
    #
    #     re1 = [(i[0],i[1]) for i in sorted(list(enumerate(index_list)), key=lambda x: x[1], reverse=True)]
    #
    #     for i in range(0, 10):
    #         print(re1[i])
    #         print(df_train_nuoche[re1[i][0]])
    # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

    path_txt = "data/10235513_大型商业建筑人员疏散设计研究_沈福禹/大型商业建筑人员疏散设计研究.txt"
    path_excel = "data/10235513_大型商业建筑人员疏散设计研究_沈福禹/大型商业建筑人员疏散设计研究_2.xlsx"

    # The context manager closes the file even if reading raises.
    with open(path_txt, encoding="utf-8") as f:
        centent = f.read()

    data_zong = []
    for text in centent.split("\n"):
        # Lines starting with five asterisks are skipped, not embedded.
        if text.startswith("*****"):
            continue

        # Embed the line and flatten the model output to a 1-D vector.
        data = classifymodel.data_generator([text], batch_size)
        token, segment = data[0][0], data[1][0]
        content_cls = classifymodel.predict(token, segment).reshape(-1)

        # Score every reference vector, then order (index, score) pairs from
        # the most to the least similar.
        index_list = [cos_sim(content_cls, vec) for vec in b]
        re1 = sorted(enumerate(index_list), key=lambda pair: pair[1], reverse=True)

        # Output row layout: the input text followed by (score, column 0,
        # file name) for each of the best matches. Slicing instead of indexing
        # 0..9 avoids an IndexError when fewer than `top_k` vectors exist.
        dan_data = [text]
        for ref_index, score in re1[:top_k]:
            ref_row = df_train_nuoche[ref_index]
            dan_data.append(score)
            dan_data.append(ref_row[0])
            dan_data.append(ref_row[1].split("\\")[-1])
        data_zong.append(dan_data)

    pd.DataFrame(data_zong).to_excel(path_excel, index=False)