# -*- coding: utf-8 -*-

"""
@Time    :  2023/3/9 18:36
@Author  : 
@FileName: 
@Software: 
@Describe:
"""
import os

# Select the GPU before TensorFlow is imported.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
import pandas as pd
import tensorflow as tf
from numpy.linalg import norm

from src.basemodel import ClassifyModel


def cos_sim(a, b):
    """Cosine similarity between two 1-D vectors (no zero-norm guard)."""
    A = np.array(a)
    B = np.array(b)
    return np.dot(A, B) / (norm(A) * norm(B))
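# Example values: cos_sim([1, 0], [0, 1]) == 0.0 and cos_sim([1, 2], [2, 4]) == 1.0 (up to float error).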


if __name__ == '__main__':
    maxlen = 512
    batch_size = 32
    # BERT configuration (chinese_roberta_wwm_ext)
    config_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_config.json'
    checkpoint_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
    dict_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'

    # Pre-computed reference sentence vectors, plus the reference rows they were built from
    # (assumed: column 0 = reference text, column 1 = source file path).
    label_vec_path = "data/10235513_大型商业建筑人员疏散设计研究_沈福禹/save_x.npy"
    b = np.load(label_vec_path)
    df_train_nuoche = pd.read_csv("data/10235513_大型商业建筑人员疏散设计研究_沈福禹/查重.csv", encoding="utf-8").values.tolist()

    classifymodel = ClassifyModel(config_path, checkpoint_path, dict_path, is_train=False, load_weights_path=None)
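    # ClassifyModel is assumed to wrap the BERT encoder from the checkpoint above; with
    # is_train=False and load_weights_path=None it presumably runs inference on the
    # pretrained weights only (no fine-tuned head is loaded).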

# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
#     # Interactive mode (disabled): type a sentence and print its 10 most similar reference rows.
#     while True:
#         text = input("Enter a sentence: ")
#         data = classifymodel.data_generator([text], batch_size)
#         token, segment = data[0][0], data[1][0]
#         content_cls = classifymodel.predict(token, segment)
#         content_cls = content_cls.reshape(-1)
#         print(content_cls.shape)
#
#         index_list = [cos_sim(content_cls, vec) for vec in b]
#         re1 = sorted(enumerate(index_list), key=lambda x: x[1], reverse=True)
#
#         for i in range(10):
#             print(re1[i])
#             print(df_train_nuoche[re1[i][0]])
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    path_txt = "data/10235513_大型商业建筑人员疏散设计研究_沈福禹/大型商业建筑人员疏散设计研究.txt"
    path_excel = "data/10235513_大型商业建筑人员疏散设计研究_沈福禹/大型商业建筑人员疏散设计研究_2.xlsx"
    f = open(path_txt, encoding="utf-8")
    centent = f.read()
    f.close()

    data_zong = []
    centent_list = centent.split("\n")
    for text in centent_list:
        if text[:5] == "*****":
            continue
        dan_data = [text]
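        # Encode the paragraph into a single fixed-size vector (presumably the [CLS]
        # output of ClassifyModel), flattened so it can be compared with cos_sim.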
        data = classifymodel.data_generator([text], batch_size)
        token, segment = data[0][0], data[1][0]
        content_cls = classifymodel.predict(token, segment)
        content_cls = content_cls.reshape(-1)

        # Cosine similarity of this paragraph against every reference vector,
        # sorted into (reference_index, similarity) pairs, most similar first.
        index_list = [cos_sim(content_cls, vec) for vec in b]
        re1 = sorted(enumerate(index_list), key=lambda x: x[1], reverse=True)

        # Keep the top-10 matches: similarity score, matched reference text, and source
        # file name (taken from the Windows-style path in column 1 of the reference table).
        for ref_index, score in re1[:10]:
            dan_data.append(score)
            dan_data.append(df_train_nuoche[ref_index][0])
            filename = df_train_nuoche[ref_index][1].split("\\")[-1]
            dan_data.append(filename)
        data_zong.append(dan_data)
    # One row per paragraph: [paragraph, score_1, text_1, file_1, ..., score_10, text_10, file_10].
    pd.DataFrame(data_zong).to_excel(path_excel, index=False)