# -*- coding: utf-8 -*-

"""
@Time    :  2023/3/10 18:53
@Author  : 
@FileName: 
@Software: 
@Describe:
"""
import os
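# Pin the run to GPU 0; these variables must be set before any CUDA-aware
# framework (TensorFlow/Keras) is imported.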
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
import pandas as pd
from tqdm import tqdm

from src.basemodel import ClassifyModel


if __name__ == '__main__':
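    # Encoding hyperparameters: maximum sequence length and batch size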
    maxlen = 256
    batch_size = 32
    # BERT configuration: paths to the pre-trained chinese_roberta_wwm_ext checkpoint
    config_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_config.json'
    checkpoint_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
    dict_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'

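    # Quick smoke-test input: 34 copies of the same short Chinese sentence
    # (the commented-out block below runs the same prediction pipeline on it).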
    texts = ["我们有个好朋友"] * 34
    print(texts)
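    # Build the classifier in inference mode; load_weights_path=None presumably
    # means only the pre-trained BERT weights are used (no fine-tuned head).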
    classifymodel = ClassifyModel(config_path, checkpoint_path, dict_path, is_train=False, load_weights_path=None)
    # data = classifymodel.data_generator(texts, batch_size)
    # for token, segment in zip(data[0],data[1]):
    #     print(classifymodel.predict(token, segment).shape)

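    # Load the duplicate-check ("查重") CSV for this document.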
    df_train_nuoche = pd.read_csv("data/10235513_大型商业建筑人员疏散设计研究_沈福禹/查重.csv", encoding="utf-8")
    # Take the first column (the text to encode) as a plain Python list
    Data = df_train_nuoche.iloc[:, 0].tolist()
    print(Data[0])
    print(len(Data))

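    # Tokenize and batch all texts; data is assumed to be a pair of
    # (token_id batches, segment_id batches) as consumed by predict() below.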
    data = classifymodel.data_generator(Data, batch_size)

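    # Sanity check: size of the final batch (may be smaller than batch_size)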
    print(len(data[0][-1]))
    # Accumulate per-batch embeddings and concatenate once (cheaper than growing an array each step)
    batch_embeddings = []
    for token, segment in tqdm(zip(data[0], data[1]), total=len(data[0])):
        batch_embeddings.append(classifymodel.predict(token, segment))
    a1 = np.concatenate(batch_embeddings, axis=0)

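    # a1 should now hold one 768-dimensional vector per input text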
    print(a1.shape)
    np.save('data/10235513_大型商业建筑人员疏散设计研究_沈福禹/save_x', a1)
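    # np.save appends the .npy suffix; to reload the embeddings later, e.g.:
    # x = np.load('data/10235513_大型商业建筑人员疏散设计研究_沈福禹/save_x.npy')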