You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
100 lines
3.6 KiB
100 lines
3.6 KiB
# -*- coding: utf-8 -*-
"""
@Time    : 2023/3/9 18:36
@Author  :
@FileName:
@Software:
@Describe: Embed each line of a query document with a BERT-based encoder and
           rank pre-computed corpus sentence vectors by cosine similarity,
           writing the top matches per line to an Excel report.
"""
|
#! -*- coding: utf-8 -*-
# NOTE(review): the comments below were copied from an unrelated CRF-based
# Chinese NER example and do not describe this script. Kept for provenance:
#   original dataset: http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
#   reported F1 there: dev 96.18%, test 95.35%
import os

# These must be set BEFORE TensorFlow initializes CUDA: enumerate GPUs in
# PCI bus order and expose only GPU 0 to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import tensorflow as tf

import numpy as np
import pandas as pd
from numpy.linalg import norm

from src.basemodel import ClassifyModel
|
|
|
|
# a = [[1, 3, 2], [2, 2, 1]]
|
|
# print(cosine_similarity(a))
|
|
|
|
def cos_sim(a, b):
    """Return the cosine similarity of vectors *a* and *b*.

    Both arguments are converted to NumPy arrays; the result is
    dot(a, b) / (||a|| * ||b||), i.e. +1 for parallel vectors,
    0 for orthogonal ones and -1 for opposite directions.
    """
    vec_a = np.array(a)
    vec_b = np.array(b)
    denominator = norm(vec_a) * norm(vec_b)
    return np.dot(vec_a, vec_b) / denominator
|
|
|
|
|
|
if __name__ == '__main__':
    batch_size = 32

    # BERT (RoBERTa-wwm-ext) configuration files.
    config_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_config.json'
    checkpoint_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
    dict_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'

    # Pre-computed corpus embeddings, one vector per corpus row.
    lable_vec_path = "data/10235513_大型商业建筑人员疏散设计研究_沈福禹/save_x.npy"
    corpus_vecs = np.load(lable_vec_path)

    # Corpus rows aligned with corpus_vecs: column 0 is the sentence text,
    # column 1 a source file path — TODO confirm against the CSV schema.
    df_train_nuoche = pd.read_csv(
        "data/10235513_大型商业建筑人员疏散设计研究_沈福禹/查重.csv",
        encoding="utf-8",
    ).values.tolist()

    # Inference-only encoder; weights come from the checkpoint above.
    classifymodel = ClassifyModel(config_path, checkpoint_path, dict_path,
                                  is_train=False, load_weights_path=None)

    path_txt = "data/10235513_大型商业建筑人员疏散设计研究_沈福禹/大型商业建筑人员疏散设计研究.txt"
    path_excel = "data/10235513_大型商业建筑人员疏散设计研究_沈福禹/大型商业建筑人员疏散设计研究_2.xlsx"

    # Read the whole query document; the context manager guarantees the
    # handle is closed even if reading fails.
    with open(path_txt, encoding="utf-8") as f:
        content = f.read()

    data_zong = []
    for text in content.split("\n"):
        # Lines starting with "*****" are separators in the source
        # document — presumably section markers; skip them.
        if text.startswith("*****"):
            continue
        dan_data = [text]

        # Embed the line and flatten the encoder output to a 1-D vector.
        data = classifymodel.data_generator([text], batch_size)
        token, segment = data[0][0], data[1][0]
        content_cls = classifymodel.predict(token, segment).reshape(-1)

        # Cosine similarity against every pre-computed corpus vector.
        index_list = [cos_sim(content_cls, vec) for vec in corpus_vecs]

        # (corpus_index, similarity) pairs, most similar first.
        re1 = sorted(enumerate(index_list), key=lambda x: x[1], reverse=True)

        # Keep the 10 best matches: score, matched text, source file name.
        for idx, score in re1[:10]:
            dan_data.append(score)
            dan_data.append(df_train_nuoche[idx][0])
            # The stored path uses Windows separators; keep the file name only.
            filename = df_train_nuoche[idx][1].split("\\")[-1]
            dan_data.append(filename)
        data_zong.append(dan_data)

    # One row per document line: [text, score1, match1, file1, score2, ...].
    pd.DataFrame(data_zong).to_excel(path_excel, index=False)
|
|
|