#! -*- coding: utf-8 -*-

import os
# os.environ["TF_KERAS"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import glob
import random
from tqdm import tqdm
import numpy as np
import pandas as pd
from bert4keras.backend import keras, K
from bert4keras.layers import Loss
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer, load_vocab
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, open
from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder
from keras.models import Model
import tensorflow as tf



from keras.backend import set_session

# TF1-style session handling: keep a single global graph/session (with GPU
# memory growth enabled) so the Keras model can later be called from other threads.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
global graph
graph = tf.get_default_graph()
sess = tf.Session(graph=graph, config=config)
set_session(sess)


# Basic parameters

class TotalLoss(Loss):
    """The loss has two parts: the seq2seq cross-entropy and the
    in-batch similarity cross-entropy.
    """
    def compute_loss(self, inputs, mask=None):
        loss1 = self.compute_loss_of_seq2seq(inputs, mask)
        loss2 = self.compute_loss_of_similarity(inputs, mask)
        self.add_metric(loss1, name='seq2seq_loss')
        self.add_metric(loss2, name='similarity_loss')
        return loss1 + loss2

    def compute_loss_of_seq2seq(self, inputs, mask=None):
        y_true, y_mask, _, y_pred = inputs
        y_true = y_true[:, 1:]  # target token_ids
        y_mask = y_mask[:, 1:]  # segment_ids, which mark exactly the part to predict
        y_pred = y_pred[:, :-1]  # predictions, shifted by one position
        loss = K.sparse_categorical_crossentropy(y_true, y_pred)
        loss = K.sum(loss * y_mask) / K.sum(y_mask)
        return loss

    def compute_loss_of_similarity(self, inputs, mask=None):
        _, _, y_pred, _ = inputs
        y_true = self.get_labels_of_similarity(y_pred)  # build the labels
        y_pred = K.l2_normalize(y_pred, axis=1)  # normalise the sentence vectors
        similarities = K.dot(y_pred, K.transpose(y_pred))  # similarity matrix
        similarities = similarities - K.eye(K.shape(y_pred)[0]) * 1e12  # mask out the diagonal
        similarities = similarities * 30  # scale
        loss = K.categorical_crossentropy(
            y_true, similarities, from_logits=True
        )
        return loss

    def get_labels_of_similarity(self, y_pred):
        idxs = K.arange(0, K.shape(y_pred)[0])
        idxs_1 = idxs[None, :]
        idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
        labels = K.equal(idxs_1, idxs_2)
        labels = K.cast(labels, K.floatx())
        return labels
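
# A minimal NumPy sketch (illustration only, not used by the model) of how
# get_labels_of_similarity pairs sentences: with a batch laid out as
# [a, a', b, b', ...], sample 2i and sample 2i+1 are each other's positive
# example. The batch size of 4 below is an assumption for illustration.
def _demo_similarity_labels(batch_size=4):
    idxs = np.arange(batch_size)
    idxs_1 = idxs[None, :]
    idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]  # 0 <-> 1, 2 <-> 3, ...
    labels = (idxs_1 == idxs_2).astype('float32')
    return labels  # e.g. row 0 has a 1 in column 1, row 1 has a 1 in column 0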


class GenerateModel(object):
    def __init__(self):

        self.epoch_acc_vel = 0
        self.config_path = r'./chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_config.json'
        self.checkpoint_path = r'./chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
        self.dict_path = r'./chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'
        self.maxlen = 120
        self.novel_maxlen = 60

    def device_setup(self):
        token_dict, keep_tokens = load_vocab(
            dict_path=self.dict_path,
            simplified=True,
            startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
        )
        tokenizer = Tokenizer(token_dict, do_lower_case=True)

        bert = build_transformer_model(
            self.config_path,
            self.checkpoint_path,
            with_pool='linear',
            application='unilm',
            keep_tokens=keep_tokens,  # keep only the tokens in keep_tokens to shrink the vocabulary
            return_keras_model=False,
        )

        encoder = keras.models.Model(bert.model.inputs, bert.model.outputs[0])
        seq2seq = keras.models.Model(bert.model.inputs, bert.model.outputs[1])

        outputs = TotalLoss([2, 3])(bert.model.inputs + bert.model.outputs)
        model = keras.models.Model(bert.model.inputs, outputs)

        path_model = './output_simbert_yy/best_simbertmodel_datasim.weights'
        model.load_weights(path_model)

        return encoder, seq2seq, tokenizer


class CrossEntropy(Loss):
    """Cross-entropy loss with the input (source) part masked out.
    """
    def compute_loss(self, inputs, mask=None):
        y_true, y_mask, y_pred = inputs
        y_true = y_true[:, 1:]  # target token_ids
        y_mask = y_mask[:, 1:]  # segment_ids, which mark exactly the part to predict
        y_pred = y_pred[:, :-1]  # predictions, shifted by one position
        loss = K.sparse_categorical_crossentropy(y_true, y_pred)
        loss = K.sum(loss * y_mask) / K.sum(y_mask)
        return loss
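
# A minimal NumPy sketch (illustration only, not part of the model) of the
# shift-by-one masking used in CrossEntropy/TotalLoss: the prediction at
# position t is scored against the token at position t+1, and segment_ids
# (0 for the source, 1 for the target) mask out the source part. This helper
# is an assumption for illustration and is not called anywhere.
def _demo_shifted_mask(token_ids, segment_ids):
    y_true = np.asarray(token_ids)[:, 1:]    # targets: tokens shifted left by one
    y_mask = np.asarray(segment_ids)[:, 1:]  # 1 exactly where a target token must be predicted
    # the model's y_pred[:, :-1] would be aligned with y_true / y_mask here
    return y_true * y_mask                   # non-target positions are zeroed out
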

class AutoTitle(AutoRegressiveDecoder):
    """seq2seq decoder.
    """
    def __init__(self, model, tokenizer, start_id, end_id, maxlen, minlen=1):
        super(AutoTitle, self).__init__(start_id, end_id, maxlen, minlen)
        self.model = model
        self.tokenizer = tokenizer
        self.start_id = start_id
        self.end_id = end_id
        self.minlen = minlen
        self.models = {}
        if start_id is None:
            self.first_output_ids = np.empty((1, 0), dtype=int)
        else:
            self.first_output_ids = np.array([[self.start_id]])

    def data_generator(self, inputs, output_ids):
        try:
            batch_token_ids, batch_segment_ids = [], []
            if output_ids == []:
                for txt in inputs:
                    token_ids, segment_ids = self.tokenizer.encode(txt, maxlen=120)
                    batch_token_ids.append(token_ids)
                    batch_segment_ids.append(segment_ids)
            else:
                for txt, output_id in zip(inputs, output_ids):
                    token_ids, segment_ids = self.tokenizer.encode(txt, output_id)
                    batch_token_ids.append(token_ids[:-1])
                    batch_segment_ids.append(segment_ids[:-1])

            batch_token_ids = sequence_padding(batch_token_ids)
            batch_segment_ids = sequence_padding(batch_segment_ids)
        except Exception:
            print(inputs, output_ids)
            raise
        return batch_token_ids, batch_segment_ids

    def beam_search_batch_(self, inputs, topk, states=None, temperature=1, min_ends=1):
        """Beam search decoding.
        Note: topk is the beam size.
        Returns: the best decoded sequence.
        """
        inputs = [np.array([i]) for i in inputs]
        output_ids, output_scores = self.first_output_ids, np.zeros(1)
        for step in range(self.maxlen):
            scores, states = self.predict(
                inputs, output_ids, states, temperature, 'logits'
            )  # scores for the current step
            if step == 0:  # after the first step, repeat the inputs topk times
                inputs = [np.repeat(i, topk, axis=0) for i in inputs]
            scores = output_scores.reshape((-1, 1)) + scores  # accumulated scores
            indices = scores.argpartition(-topk, axis=None)[-topk:]  # keep only the topk
            indices_1 = indices // scores.shape[1]  # row indices
            indices_2 = (indices % scores.shape[1]).reshape((-1, 1))  # column indices
            output_ids = np.concatenate([output_ids[indices_1], indices_2],
                                        1)  # update the outputs
            output_scores = np.take_along_axis(
                scores, indices, axis=None
            )  # update the scores
            is_end = output_ids[:, -1] == self.end_id  # whether a sequence ends with the end token
            end_counts = (output_ids == self.end_id).sum(1)  # count the end tokens
            if output_ids.shape[1] >= self.minlen:  # minimum-length check
                best = output_scores.argmax()  # the highest-scoring candidate
                if is_end[best] and end_counts[best] >= min_ends:  # already finished
                    return output_ids[best]  # return it directly
                else:  # otherwise keep only the unfinished candidates
                    flag = ~is_end | (end_counts < min_ends)  # mark unfinished sequences
                    if not flag.all():  # some candidates have finished
                        inputs = [i[flag] for i in inputs]  # drop finished sequences
                        output_ids = output_ids[flag]  # drop finished sequences
                        output_scores = output_scores[flag]  # drop finished sequences
                        end_counts = end_counts[flag]  # drop finished end counts
                        topk = flag.sum()  # shrink topk accordingly
        # maximum length reached, return the best candidate
        return output_ids[output_scores.argmax()]


    def random_sample_batch(
        self,
        inputs,
        n,
        topk=None,
        topp=None,
        states=None,
        temperature=1,
        min_ends=1
    ):
        """Randomly sample n results.
        Note: a non-None topk means that at each step we only sample from the
              topk most probable tokens; a non-None topp means that at each step
              we only sample from the most probable tokens whose cumulative
              probability just reaches topp.
        Returns: a list of n decoded sequences.
        """
        inputs = [np.array([i]) for i in inputs]
        output_ids = self.first_output_ids
        results = []
        for step in range(self.maxlen):
            probas, states = self.predict(
                inputs, output_ids, states, temperature, 'probas'
            )  # probabilities for the current step
            probas /= probas.sum(axis=1, keepdims=True)  # make sure they are normalised
            if step == 0:  # after the first step, repeat everything n times
                probas = np.repeat(probas, n, axis=0)
                inputs = [np.repeat(i, n, axis=0) for i in inputs]
                output_ids = np.repeat(output_ids, n, axis=0)
            if topk is not None:
                k_indices = probas.argpartition(-topk,
                                                axis=1)[:, -topk:]  # keep only the topk
                probas = np.take_along_axis(probas, k_indices, axis=1)  # topk probabilities
                probas /= probas.sum(axis=1, keepdims=True)  # renormalise
            if topp is not None:
                p_indices = probas.argsort(axis=1)[:, ::-1]  # sort high to low
                probas = np.take_along_axis(probas, p_indices, axis=1)  # sorted probabilities
                cumsum_probas = np.cumsum(probas, axis=1)  # cumulative probabilities
                flag = np.roll(cumsum_probas >= topp, 1, axis=1)  # mark the part beyond topp
                flag[:, 0] = False  # combined with np.roll, this shifts the mask by one
                probas[flag] = 0  # zero out everything beyond topp
                probas /= probas.sum(axis=1, keepdims=True)  # renormalise
            sample_func = lambda p: np.random.choice(len(p), p=p)  # sampling function
            sample_ids = np.apply_along_axis(sample_func, 1, probas)  # draw the samples
            sample_ids = sample_ids.reshape((-1, 1))  # align the shape
            if topp is not None:
                sample_ids = np.take_along_axis(
                    p_indices, sample_ids, axis=1
                )  # map back to the original token ids
            if topk is not None:
                sample_ids = np.take_along_axis(
                    k_indices, sample_ids, axis=1
                )  # map back to the original token ids
            output_ids = np.concatenate([output_ids, sample_ids], 1)  # update the outputs
            is_end = output_ids[:, -1] == self.end_id  # whether a sequence ends with the end token
            end_counts = (output_ids == self.end_id).sum(1)  # count the end tokens
            if output_ids.shape[1] >= self.minlen:  # minimum-length check
                flag = is_end & (end_counts >= min_ends)  # mark finished sequences
                if flag.any():  # some sequences have finished
                    for ids in output_ids[flag]:  # store the finished sequences
                        results.append(ids)
                    flag = (flag == False)  # mark unfinished sequences
                    inputs = [i[flag] for i in inputs]  # keep only the unfinished inputs
                    output_ids = output_ids[flag]  # keep only the unfinished candidates
                    end_counts = end_counts[flag]  # keep only the unfinished end counts
                    if len(output_ids) == 0:
                        break
        # put any remaining unfinished sequences into the results
        for ids in output_ids:
            results.append(ids)
        # return the results
        return results

    def random_sample_and_beam_search(
        self,
        inputs,
        n,
        topk=None,
        topp=None,
        states=None,
        temperature=1,
        min_ends=1
    ):
        """Run random sampling and beam search at the same time.
        Note: n is the number of randomly sampled sequences; topk is the beam
              size of the beam search part.
        Returns: (random-sample results, beam-search results).
        """
        whether_end_b = False
        results_r = []
        results_b = []
        index_r = np.arange(n)
        index_b = np.arange(topk)
        inputs = [np.array([i]) for i in inputs]
        output_ids, output_scores = self.first_output_ids, np.zeros(1)
        results = []
        for step in range(self.maxlen):
            beam_n = len(index_b)
            probas, states = self.predict(
                inputs, output_ids, states, temperature, 'probas'
            )  # probabilities for the current step
            probas = probas / probas.sum(axis=1, keepdims=True)  # make sure they are normalised
            if step == 0:  # after the first step, repeat everything n (+ topk) times
                probas = np.repeat(probas, n + topk, axis=0)
                inputs_r = [np.repeat(i, n, axis=0) for i in inputs]
                output_ids = np.repeat(output_ids, n + topk, axis=0)
                inputs_b = [np.repeat(i, topk, axis=0) for i in inputs]
            else:
                if whether_end_b == False:
                    inputs_r = [i[:-beam_n, :] for i in inputs]
                    inputs_b = [i[-beam_n:, :] for i in inputs]
                else:
                    inputs_r = inputs
            if whether_end_b == False:
                probas_r = probas[:-beam_n, :]
            else:
                probas_r = probas
            if step == 0:
                probas_b = probas[0, :]
            else:
                probas_b = probas[-beam_n:, :]

            if whether_end_b == False:
                output_ids_r = output_ids[:-beam_n, :]
                output_ids_b = output_ids[-beam_n:, :]
            else:
                output_ids_r = output_ids
            k_indices = probas_r.argpartition(-topk,
                                            axis=1)[:, -topk:]  # keep only the topk
            probas_r = np.take_along_axis(probas_r, k_indices, axis=1)  # topk probabilities
            probas_r /= probas_r.sum(axis=1, keepdims=True)  # renormalise

            if whether_end_b == False:
                scores = output_scores.reshape((-1, 1)) + probas_b  # accumulated scores
                indices = scores.argpartition(-topk, axis=None)[-topk:]  # keep only the topk
                indices_1 = indices // scores.shape[1]  # row indices
                indices_2 = (indices % scores.shape[1]).reshape((-1, 1))  # column indices
                try:
                    output_ids_b = np.concatenate([output_ids_b[indices_1], indices_2],
                                                1)  # update the outputs
                except Exception:
                    print(output_ids_b.shape)
                    print(indices_1)
                    print(indices_2)
                    exit()
                output_scores = np.take_along_axis(
                    scores, indices, axis=None
                )  # update the scores
            sample_func = lambda p: np.random.choice(len(p), p=p)  # sampling function
            try:
                sample_ids = np.apply_along_axis(sample_func, 1, probas_r)  # draw the samples
            except Exception:
                print(probas_r)
            sample_ids = sample_ids.reshape((-1, 1))  # align the shape
            if topk is not None:
                sample_ids = np.take_along_axis(
                    k_indices, sample_ids, axis=1
                )  # map back to the original token ids
            output_ids_r = np.concatenate([output_ids_r, sample_ids], 1)  # update the outputs

            if whether_end_b == False:
                is_end_r = output_ids_r[:, -1] == self.end_id  # whether a sampled sequence has ended
                is_end_b = output_ids_b[:, -1] == self.end_id  # whether a beam sequence has ended
            else:
                is_end_r = output_ids_r[:, -1] == self.end_id

            if whether_end_b == False:
                end_counts_r = (output_ids_r == self.end_id).sum(1)  # count the end tokens
                end_counts_b = (output_ids_b == self.end_id).sum(1)  # count the end tokens
            else:
                end_counts_r = (output_ids_r == self.end_id).sum(1)
            # random sampling part
            if output_ids_r.shape[1] >= self.minlen:  # minimum-length check
                flag = is_end_r & (end_counts_r >= min_ends)  # mark finished sequences
                if flag.any():  # some sequences have finished
                    for ids in output_ids_r[flag]:  # store the finished sequences
                        results_r.append(ids)
                    flag = (flag == False)  # mark unfinished sequences
                    try:
                        index_r = index_r[flag]
                    except Exception:
                        print("flag", flag)
                        print("index_r", index_r)
                    inputs_r = [i[flag] for i in inputs_r]  # keep only the unfinished inputs
                    output_ids_r = output_ids_r[flag]  # keep only the unfinished candidates
                    end_counts_r = end_counts_r[flag]  # keep only the unfinished end counts

            # beam search part
            if whether_end_b == False:
                if output_ids_b.shape[1] >= self.minlen:  # minimum-length check
                    best = output_scores.argmax()  # the highest-scoring candidate
                    if is_end_b[best] and end_counts_b[best] >= min_ends:  # already finished
                        results_b.append(output_ids_b[best])  # output it directly
                        whether_end_b = True
                    else:  # otherwise keep only the unfinished candidates
                        flag_b = ~is_end_b | (end_counts_b < min_ends)  # mark unfinished sequences
                        if not flag_b.all():  # some candidates have finished
                            index_b = index_b[flag_b]
                            inputs_b = [i[flag_b] for i in inputs_b]  # drop finished sequences
                            output_ids_b = output_ids_b[flag_b]  # drop finished sequences
                            output_scores = output_scores[flag_b]  # drop finished sequences
                            end_counts_b = end_counts_b[flag_b]  # drop finished end counts
                            topk = flag_b.sum()  # shrink topk accordingly

            if whether_end_b == False and len(output_ids_r) != 0:
                token_r = inputs_r[0]
                sample_ids_r = inputs_r[1]
                token_b = inputs_b[0]
                sample_ids_b = inputs_b[1]
                token = np.concatenate([token_r, token_b], 0)
                sample_ids = np.concatenate([sample_ids_r, sample_ids_b], 0)
                inputs = [token, sample_ids]
                output_ids = np.concatenate([output_ids_r, output_ids_b], 0)
            elif whether_end_b == True and len(output_ids_r) != 0:
                inputs = inputs_r
                output_ids = output_ids_r
            elif whether_end_b == False and len(output_ids_r) == 0:
                inputs = inputs_b
                output_ids = output_ids_b
            else:
                break


        # put any remaining unfinished sequences into the results
        for ids in output_ids:
            results.append(ids)
        # return the results
        return results_r, results_b

    def beam_search(self, inputs, topk, states=None, temperature=1, min_ends=1):
        """Beam search decoding.
        Note: topk is the beam size.
        Returns: the best decoded sequence.
        """
        inputs = [np.array([i]) for i in inputs]
        output_ids, output_scores = self.first_output_ids, np.zeros(1)
        for step in range(self.maxlen):
            scores, states = self.predict(
                inputs, output_ids, states, temperature, 'probas'
            )  # scores for the current step
            if step == 0:  # after the first step, repeat the inputs topk times
                inputs = [np.repeat(i, topk, axis=0) for i in inputs]
            scores = output_scores.reshape((-1, 1)) + scores  # accumulated scores
            indices = scores.argpartition(-topk, axis=None)[-topk:]  # keep only the topk
            indices_1 = indices // scores.shape[1]  # row indices
            indices_2 = (indices % scores.shape[1]).reshape((-1, 1))  # column indices
            output_ids = np.concatenate([output_ids[indices_1], indices_2],
                                        1)  # update the outputs
            output_scores = np.take_along_axis(
                scores, indices, axis=None
            )  # update the scores
            is_end = output_ids[:, -1] == self.end_id  # whether a sequence ends with the end token
            end_counts = (output_ids == self.end_id).sum(1)  # count the end tokens
            if output_ids.shape[1] >= self.minlen:  # minimum-length check
                best = output_scores.argmax()  # the highest-scoring candidate
                if is_end[best] and end_counts[best] >= min_ends:  # already finished
                    return output_ids[best]  # return it directly
                else:  # otherwise keep only the unfinished candidates
                    flag = ~is_end | (end_counts < min_ends)  # mark unfinished sequences
                    if not flag.all():  # some candidates have finished
                        inputs = [i[flag] for i in inputs]  # drop finished sequences
                        output_ids = output_ids[flag]  # drop finished sequences
                        output_scores = output_scores[flag]  # drop finished sequences
                        end_counts = end_counts[flag]  # drop finished end counts
                        topk = flag.sum()  # shrink topk accordingly
        # maximum length reached, return the best candidate
        return output_ids[output_scores.argmax()]


    def beam_search_batch(
        self,
        inputs_str,
        states=None,
        temperature=1,
        min_ends=1,
        num_beam=3
    ):
        """Batched beam search over a list of input texts.
        Note: num_beam is the beam size; each input keeps its own beam state in
              a Beamdataone object and is dropped from the batch once finished.
        Returns: a list with one decoded string per input (an entry stays 0 if
              that input does not finish within maxlen steps).
        """
        batch_nums = len(inputs_str)
        return_str_batch = [0] * batch_nums
        output_ids = np.empty((1, 0), dtype=int)
        generated = [Beamdataone(num_beam, i, [inputs_str[i]], self.end_id, self.minlen, min_ends, self.tokenizer, output_ids) for i in range(batch_nums)]
        for step in range(self.maxlen):
            text_batch, output_str_batch = [], []
            batch_input_num_beam_num = []
            for i in generated:
                text = i.text
                text_batch.extend(text)
                if i.output_str != "":
                    output_str_batch.extend(i.output_str)
                if step != 0:
                    batch_input_num_beam_num.append(i.num_beams)

            token_ids, segment_ids = self.data_generator(text_batch, output_str_batch)
            inputs = [token_ids, segment_ids]

            probas = self.predict_batch(
                    inputs
                )  # probabilities for the current step

            # keep, for every sequence, the prediction at its last real (non-padding) position
            probas_new = []
            probas_bool = np.array(token_ids, dtype=bool)
            for i, sentence in enumerate(probas_bool):
                lie = np.array(np.where(sentence == True))[0]
                probas_new.append(probas[i, lie[-1]])
            probas = np.array(probas_new)

            if step != 0:
                num = 0
                if len(generated) > 1:
                    index = 0
                    for index in range(len(batch_input_num_beam_num) - 1):
                        cc = num
                        num += batch_input_num_beam_num[index]
                        generated[index].add_data(step, probas[cc:num, :])
                    generated[index + 1].add_data(step, probas[num:, :])
                else:
                    generated[0].add_data(step, probas[:, :])
            else:
                for index in range(len(generated)):
                    generated[index].add_data(step, probas[index, :])

            # drop the inputs whose beam search has finished
            generated_new = []
            for i in range(len(generated)):
                bool_ = generated[i].over
                if bool_ == False:
                    generated_new.append(generated[i])
                else:
                    return_str_batch[generated[i].batch_id] = generated[i].return_str
            generated = generated_new

            if generated == []:
                return return_str_batch
        return return_str_batch

    def top_batch(
        self,
        inputs_str,
        temperature=1,
        min_ends=1
    ):
        """Batched greedy decoding: at every step take the argmax token for each
        input, and drop an input from the batch once it emits the end token.
        Returns: a list with one decoded id sequence per input (an entry stays
        an empty list if that input does not finish within maxlen steps).
        """
        output_str = []
        batch_nums = len(inputs_str)
        output_ids = np.empty((batch_nums, 0), dtype=int)

        results = [[] for i in range(batch_nums)]
        index_data = [i for i in range(batch_nums)]
        for step in range(self.maxlen):

            token_ids, segment_ids = self.data_generator(inputs_str, output_str)
            inputs = [token_ids, segment_ids]
            probas = self.predict_batch(
                inputs
            )  # probabilities for the current step

            # keep, for every sequence, the prediction at its last real (non-padding) position
            probas_new = []
            probas_bool = np.array(token_ids, dtype=bool)
            for i, sentence in enumerate(probas_bool):
                lie = np.array(np.where(sentence == True))[0]
                probas_new.append(probas[i, lie[-1]])
            probas = np.array(probas_new)
            k_indices = np.argmax(probas, axis=1)  # greedy: take the most probable token
            k_indices = k_indices.reshape(-1, 1)

            sample_ids = k_indices
            output_ids = np.concatenate([output_ids, sample_ids], 1)  # update the outputs
            is_end = output_ids[:, -1] == self.end_id  # whether a sequence ends with the end token
            end_counts = (output_ids == self.end_id).sum(1)  # count the end tokens
            if output_ids.shape[1] >= self.minlen:  # minimum-length check
                flag = is_end & (end_counts >= min_ends)  # mark finished sequences
                if flag.any():  # some sequences have finished
                    index = np.array(np.where(flag == True))[0]
                    pop_index = []
                    for i in index:
                        results[index_data[i]] = output_ids[i]
                        pop_index.append(index_data[i])
                    for i in pop_index:
                        index_data.remove(i)
                    flag = (flag == False)  # mark unfinished sequences
                    inputs_str = [inputs_str[i] for i in index_data]  # keep only the unfinished inputs
                    output_ids = output_ids[flag]  # keep only the unfinished candidates
                    if len(output_ids) == 0:
                        break
                    else:
                        output_str = [self.tokenizer.decode(ids) for ids in output_ids]
                else:
                    output_str = [self.tokenizer.decode(ids) for ids in output_ids]
        # return the results
        return results


    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        token_ids, segment_ids = inputs

        token_ids = np.concatenate([token_ids, output_ids], 1)
        segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
        with graph.as_default():
            K.set_session(sess)
            nodes = self.last_token(self.model).predict([token_ids, segment_ids])
        return nodes
        # return self.last_token(self.model).predict([token_ids, segment_ids])

    def predict_batch(self, inputs):
        # inputs, output_ids, states, temperature, 'probas'
        token_ids, segment_ids = inputs
        # token_ids = np.concatenate([token_ids, output_ids], 1)
        # segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
        with graph.as_default():
            K.set_session(sess)
            nodes = self.model.predict([token_ids, segment_ids])
        return nodes
        # return self.last_token(self.model).predict([token_ids, segment_ids])

    def generate(self, text, topk=1):
        token_ids, segment_ids = self.tokenizer.encode(text, maxlen=256)
        output_ids = self.beam_search([token_ids, segment_ids],
                                      topk=topk)  # beam search
        return self.tokenizer.decode(output_ids)

    def generate_random(self, text, n=20, topk=5):
        if isinstance(text, list):
            text = text[0]
        token_ids, segment_ids = self.tokenizer.encode(text, maxlen=120)
        output_ids = self.random_sample([token_ids, segment_ids], n, topk)  # random sampling
        return [self.tokenizer.decode(ids) for ids in output_ids]

    def generate_random_topp(self, text, n=20, topp=0.98):
        if isinstance(text, list):
            text = text[0]
        token_ids, segment_ids = self.tokenizer.encode(text, maxlen=120)
        output_ids = self.random_sample([token_ids, segment_ids], n, topp=topp)  # nucleus (top-p) sampling
        return [self.tokenizer.decode(ids) for ids in output_ids]

    def generate_top(self, text):
        output_ids = self.top_batch(text)  # batched greedy decoding
        return [self.tokenizer.decode(ids) for ids in output_ids]

    def generate_beam_search_batch(self, text):
        output_str = self.beam_search_batch(text)  # batched beam search
        return output_str

    def generate_random_sample_and_beam_search(self, text, n=20, topk=5):
        text = text[0]
        token_ids, segment_ids = self.tokenizer.encode(text, maxlen=120)
        output_ids_r, output_ids_b = self.random_sample_and_beam_search([token_ids, segment_ids], n=n,
                                      topk=topk)  # random sampling + beam search
        output_str_r = [self.tokenizer.decode(ids) for ids in output_ids_r]
        output_str_b = [self.tokenizer.decode(ids) for ids in output_ids_b]
        return output_str_r, output_str_b

    def gen_synonyms(self, text, n=20):
        """Generate n paraphrases of the input sentence and return them ranked
        by similarity: generate with the seq2seq part, then score and sort the
        candidates with the encoder.
        """
        r = self.generate_random_topp(text, n)
        r = [i for i in set(r) if i != text]
        r = [text] + r
        X, S = [], []
        for t in r:
            x, s = self.tokenizer.encode(t)
            X.append(x)
            S.append(s)
        X = sequence_padding(X)
        S = sequence_padding(S)
        Z = encoder.predict([X, S])
        Z /= (Z ** 2).sum(axis=1, keepdims=True) ** 0.5
        argsort = np.dot(Z[1:], -Z[0]).argsort()
        return [r[i + 1] for i in argsort]


    def gen_synonyms_short(self, text, n=20, len_s=0.9):
        """Generate n paraphrases of the input sentence, rank them by encoder
        similarity, and prefer a candidate that is shorter than the input
        (below len_s * len(text) if possible).
        """
        if isinstance(text, list):
            text = text[0]
        new_text_len = int(len(text) * len_s)
        r = self.generate_random(text, n)
        r = [i for i in set(r) if i != text]
        r = [text] + r
        X, S = [], []
        for t in r:
            x, s = self.tokenizer.encode(t)
            X.append(x)
            S.append(s)
        X = sequence_padding(X)
        S = sequence_padding(S)
        with graph.as_default():
            K.set_session(sess)
            Z = encoder.predict([X, S])
        Z /= (Z ** 2).sum(axis=1, keepdims=True) ** 0.5
        argsort = np.dot(Z[1:], -Z[0]).argsort()
        sentence_list = [r[i + 1] for i in argsort]

        # prefer the most similar candidate that is clearly shorter than the input,
        # then one that is only slightly shorter, otherwise fall back to the best candidate
        return_list = []
        for i in sentence_list:
            if len(i) < new_text_len:
                return_list.append(i)
                break

        for i in sentence_list:
            if new_text_len < len(i) < len(text):
                return_list.append(i)
                break
        if return_list != []:
            return return_list[0]
        else:
            return sentence_list[0]
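
# A minimal NumPy sketch (illustration only, not used above) of the nucleus /
# top-p filtering step that random_sample_batch performs: sort the probabilities,
# keep the smallest prefix whose cumulative mass reaches topp, zero out the rest
# and renormalise. The default topp=0.9 is an arbitrary choice for illustration.
def _demo_top_p_filter(probas, topp=0.9):
    p_indices = probas.argsort(axis=1)[:, ::-1]               # sort high to low
    sorted_p = np.take_along_axis(probas, p_indices, axis=1)  # sorted probabilities
    cumsum = np.cumsum(sorted_p, axis=1)                      # cumulative probabilities
    flag = np.roll(cumsum >= topp, 1, axis=1)                 # shift by one so the token
    flag[:, 0] = False                                        # that crosses topp is kept
    sorted_p[flag] = 0                                        # zero out everything beyond topp
    sorted_p /= sorted_p.sum(axis=1, keepdims=True)           # renormalise
    return sorted_p, p_indices  # sample from sorted_p, then map back via p_indices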


class BeamHypotheses(object):
    def __init__(self, num_beams, max_length, length_penalty):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_length = max_length - 1  # ignoring bos_token
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            # keep the hypothesis when the list is not full or it beats the worst score
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                # the list is over-full, drop the worst hypothesis
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len=None):
        """
        Whether generation for this sample is finished.
        best_sum_logprobs is the highest score among the new candidate sequences.
        """

        if len(self) < self.num_beams:
            return False
        else:
            if cur_len is None:
                cur_len = self.max_length
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            # finished if even the best new candidate cannot beat the worst stored score
            ret = self.worst_score >= cur_score
            return ret
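
# A minimal usage sketch (illustration only) of BeamHypotheses: it keeps the best
# `num_beams` finished candidates and reports completion once no live candidate
# can beat the worst stored score. The token ids and log-probs below are made up.
def _demo_beam_hypotheses():
    hyps = BeamHypotheses(num_beams=2, max_length=10, length_penalty=1.0)
    hyps.add([2, 5, 7, 3], sum_logprobs=-1.2)              # length-normalised score -0.3
    hyps.add([2, 9, 3], sum_logprobs=-2.5)                 # score -0.833
    hyps.add([2, 5, 8, 4, 3], sum_logprobs=-0.9)           # score -0.18, pushes out the worst hypothesis
    # -5.0 / 5 = -1.0 is worse than the worst stored score (-0.3), so we are done
    return hyps.is_done(best_sum_logprobs=-5.0, cur_len=5)  # True
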

class Beamdataone(object):
    def __init__(self, num_beams, batch_id, text, end_id, minlen, min_ends, tokenizer, output_ids):
        """
        Per-input beam-search state for batched decoding.
        """
        self.num_beams = num_beams
        self.batch_id = batch_id
        self.beams = []
        self.minlen = minlen
        self.min_ends = min_ends
        self.end_id = end_id
        self.text = text
        self.output_scores = np.zeros(1)
        self.output_ids = output_ids
        self.return_str = ""
        self.over = False
        self.tokenizer = tokenizer
        self.output_str = ""

    def add_data(self, step, output_probas):
        """
        Consume the model predictions for one decode step and update this
        sample's beam candidates.
        @param step: current decode step
        @param output_probas: per-beam scores over the vocabulary
        """
        scores = output_probas
        scores = self.output_scores.reshape((-1, 1)) + scores  # accumulated scores
        indices = scores.argpartition(-self.num_beams, axis=None)[-self.num_beams:]  # keep only the top beams
        indices_1 = indices // scores.shape[1]  # row indices
        indices_2 = (indices % scores.shape[1]).reshape((-1, 1))  # column indices
        self.output_ids = np.concatenate([self.output_ids[indices_1], indices_2],
                                    1)  # update the outputs
        self.output_scores = np.take_along_axis(
            scores, indices, axis=None
        )  # update the scores

        is_end = self.output_ids[:, -1] == self.end_id  # whether a beam ends with the end token
        self.end_counts = (self.output_ids == self.end_id).sum(1)  # count the end tokens
        if self.output_ids.shape[1] >= self.minlen:  # minimum-length check
            best = self.output_scores.argmax()  # the highest-scoring beam
            if is_end[best] and self.end_counts[best] >= self.min_ends:  # already finished
                self.return_str_main(self.output_ids, best)
                self.over = True
            else:  # otherwise keep only the unfinished beams
                flag = ~is_end | (self.end_counts < self.min_ends)  # mark unfinished beams
                if not flag.all():  # some beams have finished
                    self.output_ids = self.output_ids[flag]  # drop finished beams
                    self.output_scores = self.output_scores[flag]  # drop finished beams
                    self.end_counts = self.end_counts[flag]  # drop finished end counts
                    self.num_beams = flag.sum()  # shrink the beam size accordingly
                self.output_str = [self.tokenizer.decode(ids) for ids in self.output_ids]
                self.text = [self.text[0] for i in range(len(self.output_ids))]

    def return_str_main(self, output_ids, best):
        output_ids_best = output_ids[best]
        self.return_str = self.tokenizer.decode(output_ids_best)
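
# A minimal sketch (illustration only) of a single Beamdataone step. The dummy
# tokenizer, the end_id of 3 and the probability row are made up; it just shows
# how one decode step expands the per-sample beam from 1 to `num_beams` candidates.
class _DummyTokenizer(object):
    def decode(self, ids):
        return " ".join(str(i) for i in ids)

def _demo_beamdataone_step():
    beam = Beamdataone(
        num_beams=2, batch_id=0, text=["dummy input"], end_id=3,
        minlen=1, min_ends=1, tokenizer=_DummyTokenizer(),
        output_ids=np.empty((1, 0), dtype=int),
    )
    probas = np.array([[0.1, 0.6, 0.2, 0.1]])  # one step, vocabulary of size 4
    beam.add_data(step=0, output_probas=probas)
    return beam.output_ids, beam.output_str  # two beam candidates and their decoded strings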


generatemodel = GenerateModel()
encoder, seq2seq, tokenizer = generatemodel.device_setup()
autotitle = AutoTitle(seq2seq, tokenizer, start_id=None, end_id=tokenizer._token_end_id, maxlen=120)




def just_show(file):
    data = []
    try:
        with open(file, 'r', encoding="utf-8") as f:
            lines = [x.strip() for x in f if x.strip() != '']
    except UnicodeDecodeError:
        with open(file, 'r', encoding="gbk") as f:
            lines = [x.strip() for x in f if x.strip() != '']
    # s2 = u'她只能应下来。'
    # lines = pd.read_csv(file,encoding="gbk").values.tolist()
    # random.shuffle(lines)
    # lines = lines[:20]
    for s in tqdm(lines[:2]):
        print(s)
        pre = autotitle.generate_random(s)
        print(s)
        print(pre)
        # data.append([s, pre])
    # pd.DataFrame(data,columns=["原始文本","生成文本"]).to_csv("data/text_测试一万字_unilm_修正数据_小说预训练_全部数据_epoch72_反向训练.csv")


def just_show_sentence(file: list) -> str:
    """
    @param file: list whose first element is the sentence to rewrite with beam search
    """
    text = file[0]
    pre = autotitle.generate(text)
    return pre

def just_show_sentence_batch(file: list) -> object:
    text = file
    pre = autotitle.generate_beam_search_batch(text)
    return pre

def just_show_sentence_top(file: list) -> object:
    text = file
    pre = autotitle.generate_top(text)
    return pre

def just_show_csv_random(file):
    data_new = []
    data = pd.read_csv(file).values.tolist()
    for sentence in tqdm(data):
        sentence = sentence[1]
        print(sentence)
        data_new_dan = []
        data_new_dan.extend([sentence, len(sentence)])
        pre = autotitle.generate_random(sentence)
        for i in pre:
            data_new_dan.extend([i, len(i)])

        data_new.append(data_new_dan)
    pd.DataFrame(data_new).to_csv("data/###第3章 非常尴尬_sim_topK_5.csv")
    # return pre


def just_show_chachong_random(file):
    text = file[0]
    pre = autotitle.gen_synonyms(text)
    return pre


def just_show_csv_beam(file):
    data_new = []
    data = pd.read_csv(file).values.tolist()
    for sentence in tqdm(data):
        sentence = sentence[1]
        print(sentence)
        data_new_dan = []
        data_new_dan.extend([sentence, len(sentence)])
        pre = autotitle.generate([sentence])
        print(pre)
        data_new_dan.extend([pre, len(pre)])
        data_new.append(data_new_dan)
    pd.DataFrame(data_new).to_csv("data/###第3章 非常尴尬_sim_topK_1.csv")


def chulichangju_1(text, chulipangban_return_list):
    """Handle a clause longer than 60 characters: cut it at the last punctuation
    mark inside the first 60 characters, rewrite each piece, and recurse on the rest.
    """
    fuhao = [",", ",", "?", "!", "…"]
    text_1 = text[:60]
    text_2 = text[60:]
    text_1_new = ""
    for i in range(len(text_1)-1, -1, -1):
        if text_1[i] in fuhao:
            text_1_new = text_1[:i]
            text_1_new += text_1[i]
            if len(text_1_new) > 10:
                text_1_new_pre = autotitle.generate(text_1_new)
            else:
                text_1_new_pre = text_1_new
            chulipangban_return_list.append(text_1_new_pre)
            if text_2 != "":
                if i+1 != 60:
                    text_2 = text_1[i+1:] + text_2
            break
        # else:
        #     chulipangban_return_list.append(text_1)
    if text_1_new == "":
        if len(text_1) > 10:
            text_1_new_pre = autotitle.gen_synonyms_short(text_1)
        else:
            text_1_new_pre = text_1
        chulipangban_return_list.append(text_1_new_pre)
    if text_2 != "":
        chulipangban_return_list = chulichangju_1(text_2, chulipangban_return_list)
    return chulipangban_return_list
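
# A minimal sketch (illustration only) of the splitting rule used above, without
# the model call: cut a long clause at the last punctuation mark inside the first
# `width` characters, then recurse on the remainder. The helper name and the
# default width of 60 mirror chulichangju_1 but are assumptions for illustration.
def _demo_split_long_clause(text, pieces=None, width=60, marks=",,?!…"):
    pieces = [] if pieces is None else pieces
    if len(text) <= width:
        pieces.append(text)
        return pieces
    head, tail = text[:width], text[width:]
    cut = max((i for i, ch in enumerate(head) if ch in marks), default=-1)
    if cut == -1:
        # no punctuation in the first `width` characters: cut hard at the boundary
        pieces.append(head)
        return _demo_split_long_clause(tail, pieces, width, marks)
    pieces.append(head[:cut + 1])
    return _demo_split_long_clause(head[cut + 1:] + tail, pieces, width, marks)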


def chulipangban_test_1(text):
    """Split a paragraph into sentences on '。', rewrite each sentence (long ones
    are further split by chulichangju_1) and return the rewritten sentences.
    """
    sentence_list = text.split("。")
    sentence_list_new = []
    for i in sentence_list:
        if i != "":
            sentence_list_new.append(i)
    sentence_list = sentence_list_new
    return_list = []
    for sentence in sentence_list:
        if len(sentence) < 60:
            if len(sentence) > 10:
                sentence_pre = autotitle.generate(sentence)
            else:
                sentence_pre = sentence
            return_list.append(sentence_pre)
        else:
            sentence_split_list = chulichangju_1(sentence,[])
            sentence_split_text = "".join(sentence_split_list)
            return_list.append(sentence_split_text)
    return return_list


def paragraph_test(text, text_new):
    """Rewrite one paragraph and append the result to text_new."""
    text = chulipangban_test_1(text)
    text = "。".join(text)
    text_new.append(text)

    # text_new_str = "".join(text_new)
    return text_new


if __name__ == '__main__':

    # text = ["历史和当下都证明,创新是民族生存、“发展的不竭源泉”,是是自身发展的必然选择",
    #         "自身发展的必然选择",
    #         "强调轻资产经营, 更加重视经营风险的规避",
    #         "历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择",
    #         "是时代对于青年们的深切呼唤"]
    text = ["基本消除“热桥”影响。"]
    print(just_show_sentence(text))
    # print(just_show_sentence_top(text))
    # print(just_show_chachong_random(text))

    # print(tokenizer.encode("\"", maxlen=120))
    # print(just_show_sentence_batch(text))


    # path = "./data/700条论文测试.xlsx"
    # df_list = pd.read_excel(path).values.tolist()
    #
    # df_list_new = []
    # print(len(df_list))
    # for i in tqdm(df_list):
    #     try:
    #         pre = just_show_chachong_random([i[0]])
    #         df_list_new.append([i[0], i[1]] + pre)
    #     except:
    #         print(i[0])
    #         continue
    # df = pd.DataFrame(df_list_new)
    # df.to_excel("./data/700条论文测试_15.xlsx", index=None)