# -*- coding: utf-8 -*-

"""
@Time    :  2023/1/16 14:59
@Author  : 
@FileName: 
@Software: 
@Describe:
"""

import os
# os.environ["TF_KERAS"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import glob
from numpy import random
random.seed(1001)
from tqdm import tqdm
import numpy as np
import pandas as pd
import json
from bert4keras.backend import keras, K
from bert4keras.layers import Loss
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import SpTokenizer
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, open
from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder
from keras.models import Model
# from rouge import Rouge  # pip install rouge
# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import tensorflow as tf


from keras.backend import set_session

config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # grow GPU memory on demand instead of reserving it all
global graph
graph = tf.get_default_graph()
sess = tf.Session(graph=graph, config=config)
set_session(sess)
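# The default graph and session are kept as module-level globals so that later
# predict() calls (wrapped in `with graph.as_default(): K.set_session(sess)`,
# e.g. when invoked from another thread in a serving setup) re-enter this same
# graph and session.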



# Basic parameters

class GenerateModel(object):
    def __init__(self):

        self.epoch_acc_vel = 0
        self.config_path = 'mt5/mt5_base/mt5_base_config.json'
        self.checkpoint_path = 'mt5/mt5_base/model.ckpt-1000000'
        self.spm_path = 'mt5/mt5_base/sentencepiece_cn.model'
        self.keep_tokens_path = 'mt5/mt5_base/sentencepiece_cn_keep_tokens.json'
        self.maxlen = 256

    def device_setup(self):
        """Build the mT5 encoder/decoder, attach the loss head and load the fine-tuned weights."""
        tokenizer = SpTokenizer(self.spm_path, token_start=None, token_end='</s>')
        keep_tokens = json.load(open(self.keep_tokens_path))

        t5 = build_transformer_model(
            config_path=self.config_path,
            checkpoint_path=self.checkpoint_path,
            keep_tokens=keep_tokens,
            model='mt5.1.1',
            return_keras_model=False,
            name='T5',
        )

        encoder = t5.encoder
        decoder = t5.decoder
        model = t5.model
        model.summary()

        output = CrossEntropy(1)([model.inputs[1], model.outputs[0]])

        model = Model(model.inputs, output)
        path_model = "output_t5/best_model_t5.weights"
        model.load_weights(path_model)

        return encoder, decoder, model, tokenizer


class CrossEntropy(Loss):
    """Cross entropy as the loss, with the input part masked out."""

    def compute_loss(self, inputs, mask=None):
        y_true, y_pred = inputs
        y_true = y_true[:, 1:]  # target token_ids
        y_mask = K.cast(mask[1], K.floatx())[:, 1:]  # mask provided by the decoder itself
        y_pred = y_pred[:, :-1]  # predicted sequence, shifted by one position
        loss = K.sparse_categorical_crossentropy(y_true, y_pred)
        loss = K.sum(loss * y_mask) / K.sum(y_mask)
        return loss
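
    # A minimal sketch of the shift above (not executed): for decoder input ids
    # [s, a, b, </s>], y_true becomes [a, b, </s>] and y_pred keeps the first three
    # positions, so each position t is scored against the token at position t + 1.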


class Beamdataone(object):
    def __init__(self, num_beams, batch_id, text, end_id, minlen, min_ends, tokenizer, output_ids):
        """
        Initialize n-best list of hypotheses.
        """
        self.num_beams = num_beams
        self.batch_id = batch_id
        self.beams = []
        self.minlen = minlen
        self.min_ends = min_ends
        self.end_id = end_id
        self.text = text
        self.output_scores = np.zeros(1)
        self.output_ids = [output_ids]
        self.return_str = ""
        self.over = False
        self.tokenizer = tokenizer
        # self.data()
        self.output_str = ""
        self.text_2_textids(
            self.text
        )
        self.scores = np.zeros(1)
        self.inputs_vector = 0

    def text_2_textids(self,text):
        token_ids, segment_ids = self.tokenizer.encode(text[0], maxlen=120)
        self.text_ids = [token_ids]

    def add_data(self, step, output_scores):
        """Advance this sample's beam search by one step.

        @param step: current decoding step (0-based)
        @param output_scores: per-beam token scores predicted for this step
        """
        self.output_ids = np.array(self.output_ids)
        if step == 0:  # after the first step, repeat the input num_beams times
            self.text_ids = [np.repeat(i, self.num_beams, axis=0) for i in self.text_ids]
        scores = self.output_scores.reshape((-1, 1)) + output_scores  # accumulated beam scores
        indices = scores.argpartition(-self.num_beams, axis=None)[-self.num_beams:]  # keep only the top-k
        indices_1 = indices // scores.shape[1]  # row (beam) indices
        indices_2 = (indices % scores.shape[1]).reshape((-1, 1))  # column (token) indices
        self.output_ids = np.concatenate([self.output_ids[indices_1], indices_2],
                                         1)  # update decoded sequences
        self.output_scores = np.take_along_axis(
            scores, indices, axis=None
        )  # update beam scores

        is_end = self.output_ids[:, -1] == self.end_id  # does each beam end with the end token?
        self.end_counts = (self.output_ids == self.end_id).sum(1)  # count end tokens per beam
        if self.output_ids.shape[1] >= self.minlen:  # minimum-length check
            best = self.output_scores.argmax()  # highest-scoring beam
            if is_end[best] and self.end_counts[best] >= self.min_ends:  # this sample is finished
                self.return_str_main(self.output_ids, best)
                self.over = True
            else:  # otherwise keep only the unfinished beams
                flag = ~is_end | (self.end_counts < self.min_ends)  # mark unfinished sequences
                if not flag.all():  # some beams have finished
                    self.output_ids = self.output_ids[flag]  # drop finished sequences
                    self.output_scores = self.output_scores[flag]  # drop their scores
                    self.end_counts = self.end_counts[flag]  # drop their end counts
                    self.num_beams = flag.sum()  # shrink top-k accordingly
                self.output_ids = self.output_ids.tolist()
                self.output_str = [self.tokenizer.decode(ids) for ids in self.output_ids]
                self.text_ids = [self.text_ids[0] for i in range(len(self.output_ids))]



    def return_str_main(self, output_ids, best):
        output_ids_best = output_ids[best]
        self.return_str = self.tokenizer.decode(output_ids_best)
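
    # A quick sketch of the top-k bookkeeping in add_data above (hypothetical numbers):
    # with num_beams=2 and a vocabulary of 4, `scores` has shape (2, 4); argpartition
    # over the flattened array picks the 2 best candidates, index // 4 recovers which
    # beam each candidate extends, and index % 4 recovers the token id appended to it.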


class AutoTitle(AutoRegressiveDecoder):
    """seq2seq解码器
    """
    def __init__(self, encoder, decoder, model, tokenizer, start_id, end_id, maxlen, minlen=1):
        super(AutoTitle, self).__init__(start_id, end_id, maxlen, minlen)
        self.encoder = encoder
        self.decoder = decoder
        self.model = model
        self.tokenizer = tokenizer
        self.start_id = start_id
        self.end_id = end_id
        self.minlen = minlen
        self.models = {}
        if start_id is None:
            self.first_output_ids = np.empty((1, 0), dtype=int)
        else:
            self.first_output_ids = np.array([[self.start_id]])


    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        c_encoded = inputs[0]
        with graph.as_default():
            K.set_session(sess)
            nodes = self.last_token(self.decoder).predict([c_encoded, output_ids])
        return nodes
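
    # Note: AutoRegressiveDecoder.last_token(...) wraps the decoder in a model that
    # returns only the final position's token distribution, which is all that
    # step-by-step decoding needs here.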

    def predict_batch(self, inputs):
        """Decode one step for a whole batch: inputs = [encoder outputs, decoded ids so far]."""
        token_ids, output_ids = inputs
        with graph.as_default():
            K.set_session(sess)
            nodes = self.decoder.predict([token_ids, output_ids])
        return nodes

    def data_generator(self, token_ids, output_ids):
        """Pad the encoder token ids and the decoder output ids into rectangular batches."""
        batch_token_ids = sequence_padding(token_ids)
        batch_output_ids = sequence_padding(output_ids)
        return batch_token_ids, batch_output_ids

    def beam_search_batch(
        self,
        inputs_str,
        states=None,
        temperature=1,
        min_ends=1,
        num_beam=3
    ):
        """随机采样n个结果
        说明:非None的topk表示每一步只从概率最高的topk个中采样;而非None的topp
             表示每一步只从概率最高的且概率之和刚好达到topp的若干个token中采样。
        返回:n个解码序列组成的list。
        """
        output_str = []
        # token_ids, segment_ids = self.data_generator(inputs, output_ids)
        batch_nums = len(inputs_str)
        return_str_batch = [0] * batch_nums
        # output_ids = np.empty((batch_nums, 0), dtype=int)
        output_ids = np.array([self.start_id])
        generated = [Beamdataone(num_beam, i, [inputs_str[i]], self.end_id, self.minlen, min_ends, self.tokenizer, output_ids) for i in range(batch_nums)]
        # index_data = [i for i in range(batch_nums)]

        c_token_ids = []
        for i in generated:
            text_ids = i.text_ids
            c_token_ids.extend(text_ids)
        c_token_ids = sequence_padding(c_token_ids)
        c_encoded = encoder.predict(np.array(c_token_ids))

        # Slice each row of the encoder output back to its unpadded length.
        for i in range(len(generated)):
            probas_bool = np.array(generated[i].text_ids[0], dtype=bool)
            lie = np.where(probas_bool)[0]  # positions of real (non-padding) tokens
            c_encoded_dan = c_encoded[np.ix_([i], lie)]
            generated[i].inputs_vector = c_encoded_dan[0]


        for step in range(self.maxlen):
            # Gather encoder vectors and partial outputs of all still-active samples.
            inputs_vector_batch, output_ids_batch = [], []
            batch_input_num_beam_num = []
            for i in generated:
                inputs_vector_batch.append(i.inputs_vector)
                output_ids_batch.extend(i.output_ids)
                if step != 0:
                    batch_input_num_beam_num.append(i.num_beams)

            inputs = [inputs_vector_batch, output_ids_batch]
            probas = self.predict_batch(inputs)  # probabilities for the current step

            # Reduce each row of probas to a single next-token distribution.
            probas_new = []
            probas_bool = np.array(inputs_vector_batch, dtype=bool)
            for i, sentence in enumerate(probas_bool):
                lie = np.where(sentence)[0]
                probas_new.append(probas[i, lie[-1]])
            probas = np.array(probas_new)
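            # Next: split probas back out per sample (each still-active sample is
            # assumed to own num_beams rows) and advance that sample's beams.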


            if step != 0:
                num = 0
                if len(generated) > 1:
                    index = 0
                    for index in range(len(batch_input_num_beam_num)-1):
                        cc = num
                        num += batch_input_num_beam_num[index]
                        generated[index].add_data(step, probas[cc:num,:])
                    generated[index+1].add_data(step, probas[num:,:])
                else:
                    generated[0].add_data(step, probas[:,:])

            else:
                for index in range(len(generated)):
                    generated[index].add_data(step, probas[index,:])

            # Move finished samples into their result slots; keep the rest decoding.
            generated_new = []
            for g in generated:
                if not g.over:
                    generated_new.append(g)
                else:
                    return_str_batch[g.batch_id] = g.return_str
            generated = generated_new

            if not generated:
                return return_str_batch
        return return_str_batch
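
    # A minimal usage sketch (hypothetical inputs), assuming the weights above load:
    #     AutoTitle(...).beam_search_batch(["sentence one", "sentence two"])
    # returns one rewritten string per input, in the original order.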


    def generate(self, text, topk=5):
        c_token_ids, _ = self.tokenizer.encode(text, maxlen=120)
        with graph.as_default():
            K.set_session(sess)
            c_encoded = self.encoder.predict(np.array([c_token_ids]))[0]
        output_ids = self.beam_search([c_encoded], topk=topk)  # beam search decoding
        return self.tokenizer.decode([int(i) for i in output_ids])

    def generate_random(self, text, n=30, topp=0.9):
        c_token_ids, _ = self.tokenizer.encode(text, maxlen=120)
        with graph.as_default():
            K.set_session(sess)
            c_encoded = self.encoder.predict(np.array([c_token_ids]))[0]
        output_ids = self.random_sample([c_encoded], n, topp=topp)  # top-p random sampling
        text = []
        for ids in output_ids:
            text.append(self.tokenizer.decode([int(i) for i in ids]))
        return text

    def generate_beam_search_batch(self, text):
        output_str = self.beam_search_batch(text)  # batched beam search
        return output_str


generatemodel = GenerateModel()
encoder, decoder, model, tokenizer = generatemodel.device_setup()
autotitle = AutoTitle(encoder, decoder, model, tokenizer, start_id=0, end_id=tokenizer._token_end_id, maxlen=120)
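# `autotitle` exposes three decoding entry points: generate() (single-input beam
# search), generate_random() (top-p random sampling) and generate_beam_search_batch()
# (batched beam search). The helpers below are thin wrappers around them.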




def just_show_sentence(file):
    """
    @param file: list of input sentences; only the first one is rewritten.
    """
    text = file[0]
    pre = autotitle.generate(text)
    return pre


def just_show_sentence_batch(file: list) -> list:
    text = file
    pre = autotitle.generate_beam_search_batch(text)
    return pre


if __name__ == '__main__':
    # file = "train_2842.txt"
    # just_show(file)
    # text = ["历史和当下都证明,创新是民族生存、发展的不竭源泉,是自身发展的必然选择,是时代对于青年们的深切呼唤"]
    # a = just_show_sentence(text)
    # print(a)
    # print(type(a))
    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # is_novel = False
    # path = "./data/700条论文测试.xlsx"
    # df_list = pd.read_excel(path).values.tolist()
    #
    #
    # df_list_new = []
    # print(len(df_list))
    # for i in tqdm(df_list):
    #     pre = just_show_sentence([i[0]])
    #
    #     df_list_new.append([i[0], i[1], pre])
    #
    # df = pd.DataFrame(df_list_new, columns=["原文", "yy降重", "t5模型"])
    # df.to_excel("./data/700条论文测试_7.xlsx", index=None)

    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

    # import os
    #
    # file = "./data/11篇汇总txt_new.txt"
    # file_t5 = "./data/11篇汇总txt_new_predict_t5.txt"
    #
    # try:
    #     with open(file, 'r', encoding="utf-8") as f:
    #         lines = [x.strip() for x in f if x.strip() != '']
    # except:
    #     with open(file, 'r', encoding="gbk") as f:
    #         lines = [x.strip() for x in f if x.strip() != '']
    #
    # zishu = 0
    # data = []
    # for i in tqdm(lines):
    #
    #     zishu += len(i)
    #     pre = just_show_sentence([i])
    #     data.append([i, pre])
    #
    # with open(file_t5, "w", encoding='utf-8') as file:
    #     for i in data:
    #         file.write("\t".join(i) + '\n')
    #     file.close()
    # print(zishu)

    #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

    text = ["'李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会",
            "历史和当下都证明,创新是民族生存、“发展的不竭源泉”,是是自身发展的必然选择",
            "自身发展的必然选择",
            "强调轻资产经营, 更加重视经营风险的规避",
            "历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择",
            "是时代对于青年们的深切呼唤"]
    # text = ["基本消除“热桥”影响。"]
    print(just_show_sentence(text))
    # print(just_show_sentence_top(text))
    # print(just_show_chachong_random(text))

    # print(tokenizer.encode("\"", maxlen=120))
    # print(just_show_sentence_batch(text))