# -*- coding: utf-8 -*-

"""
@Time    :  2022/8/5 13:41
@Author  : 
@FileName: 
@Software: 
@Describe:
"""

# ! -*- coding: utf-8 -*-
# 微调多国语言版T5做Seq2Seq任务
# 介绍链接:https://kexue.fm/archives/7867
# 细节请看:https://github.com/bojone/t5_in_bert4keras
# 数据集:https://github.com/CLUEbenchmark/CLGE 中的CSL数据集
# 补充了评测指标bleu、rouge-1、rouge-2、rouge-l
import os
# os.environ["TF_KERAS"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import json
import numpy as np
from tqdm import tqdm
from bert4keras.backend import keras, K
from bert4keras.layers import Loss
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import SpTokenizer
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, open
from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder
from keras.models import Model
# from rouge import Rouge  # pip install rouge
# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import tensorflow as tf
# Turn on memory growth for every GPU TensorFlow can see, so memory is
# requested as needed (see the TensorFlow `set_memory_growth` docs).
# NOTE(review): CUDA_VISIBLE_DEVICES is set to "3" above, so presumably only
# that one device is listed here -- confirm on the target machine.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)


# Basic parameters
max_c_len = 128  # maxlen passed to the tokenizer when encoding source text
max_t_len = 128  # maxlen for encoding target text and for decoding
batch_size = 28  # number of samples per training batch
epochs = 10  # number of epochs passed to model.fit

# Model paths
config_path = 'mt5/mt5_base_dropout_0_3_config.json'
checkpoint_path = 'mt5/mt5_base/model.ckpt-1000000'
spm_path = 'mt5/mt5_base/sentencepiece_cn.model'
keep_tokens_path = 'mt5/mt5_base/sentencepiece_cn_keep_tokens.json'


def _read_nonempty_lines(path, encoding):
    """Return the stripped, non-empty lines of *path* decoded as *encoding*."""
    with open(path, 'r', encoding=encoding) as f:
        stripped = (x.strip() for x in f)
        return [x for x in stripped if x != '']


# Training data: one sample per line. Try UTF-8 first and fall back to GBK.
file = "data/train_new/train_yy.txt"
try:
    lines = _read_nonempty_lines(file, "utf-8")
except UnicodeDecodeError:
    # Only a decoding failure justifies the GBK retry. Other errors (missing
    # file, Ctrl-C, ...) now propagate instead of being swallowed.
    lines = _read_nonempty_lines(file, "gbk")
# Load the tokenizer
tokenizer = SpTokenizer(spm_path, token_start=None, token_end='</s>')
# Use a context manager so the keep-tokens file handle is closed promptly.
with open(keep_tokens_path) as f:
    keep_tokens = json.load(f)


class data_generator(DataGenerator):
    """Data generator producing padded (source, target) id batches.

    Each sample is a tab-separated line. Only lines with exactly three
    fields are used: field 0 is the source text and field 2 is the target
    text. Other lines are skipped.
    """
    def __iter__(self, random=False):
        """Yield ``([source_ids, target_ids], None)`` batches.

        ``random`` is forwarded to ``self.sample``. Target ids are prefixed
        with 0, the same id the decoder in this file is started with.
        """
        batch_c_token_ids, batch_t_token_ids = [], []
        for is_end, txt in self.sample(random):
            fields = txt.split('\t')
            if len(fields) == 3:
                content, content_g = fields[0], fields[2]
                c_token_ids, _ = tokenizer.encode(content, maxlen=max_c_len)
                t_token_ids, _ = tokenizer.encode(content_g, maxlen=max_t_len)
                batch_c_token_ids.append(c_token_ids)
                batch_t_token_ids.append([0] + t_token_ids)
            # The flush check must run for every sample, including skipped
            # ones. Otherwise a malformed last line would silently drop the
            # final partial batch. Never emit an empty batch.
            if batch_c_token_ids and (
                len(batch_c_token_ids) == self.batch_size or is_end
            ):
                batch_c_token_ids = sequence_padding(batch_c_token_ids)
                batch_t_token_ids = sequence_padding(batch_t_token_ids)
                yield [batch_c_token_ids, batch_t_token_ids], None
                batch_c_token_ids, batch_t_token_ids = [], []


class CrossEntropy(Loss):
    """Masked sparse cross-entropy between target ids and predictions.

    Targets and predictions are offset by one step, and the loss is averaged
    over the positions kept by the mask of the second input.
    """
    def compute_loss(self, inputs, mask=None):
        """Return the mask-weighted mean cross-entropy as a scalar tensor."""
        target_ids, predictions = inputs
        # Drop the first target position and the last prediction so that
        # each prediction is compared with the following target id.
        labels = target_ids[:, 1:]
        shifted_pred = predictions[:, :-1]
        # mask[1] is the mask of the second input (the decoder side).
        weights = K.cast(mask[1], K.floatx())[:, 1:]
        per_token = K.sparse_categorical_crossentropy(labels, shifted_pred)
        return K.sum(per_token * weights) / K.sum(weights)


# Build the mT5 v1.1 model from the checkpoint. `keep_tokens` is the id list
# loaded from `keep_tokens_path`. With return_keras_model=False the call
# returns an object that exposes the encoder, decoder and combined model
# used below.
t5 = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    keep_tokens=keep_tokens,
    model='mt5.1.1',
    return_keras_model=False,
    name='T5',
)

# Keep separate handles: encoder and decoder are used for inference in
# AutoTitle, and `model` is used for training.
encoder = t5.encoder
decoder = t5.decoder
model = t5.model
model.summary()

# Attach the loss layer: it receives the decoder input ids (model.inputs[1])
# and the model output.
# NOTE(review): the positional 1 is presumably bert4keras' `output_axis`,
# i.e. the layer passes the prediction through -- confirm against Loss.
output = CrossEntropy(1)([model.inputs[1], model.outputs[0]])

# Rebind `model` to the training model. The loss is added by the layer above,
# so compile() is given only the optimizer.
model = Model(model.inputs, output)
model.compile(optimizer=Adam(2e-4))
# Optional warm start from previously saved weights (disabled):
# path_model = "output_t5/best_model_t5_juben.weights"
# model.load_weights(path_model)

class AutoTitle(AutoRegressiveDecoder):
    """Seq2seq decoder that generates text with the module-level
    encoder/decoder pair.
    """

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        """Run one decoding step and return the decoder's last-token output."""
        encoded_source = inputs[0]
        step_model = self.last_token(decoder)
        return step_model.predict([encoded_source, output_ids])

    def generate(self, text, topk=1):
        """Encode *text*, beam-search with width *topk*, and return the
        decoded string.
        """
        source_ids, _ = tokenizer.encode(text, maxlen=max_c_len)
        source_batch = np.array([source_ids])
        # Encode once; the result is reused at every decoding step.
        encoded_source = encoder.predict(source_batch)[0]
        best_ids = self.beam_search([encoded_source], topk=topk)
        plain_ids = [int(token_id) for token_id in best_ids]
        return tokenizer.decode(plain_ids)


# Note: T5 has a rather puzzling setting: its <bos> token id is 0, i.e. <bos>
# and <pad> are both 0. That is why decoding starts from id 0 here.
autotitle = AutoTitle(
    start_id=0, end_id=tokenizer._token_end_id, maxlen=max_t_len
)


# class Evaluator(keras.callbacks.Callback):
#     """评估与保存
#     """
#
#     def __init__(self):
#         self.rouge = Rouge()
#         self.smooth = SmoothingFunction().method1
#         self.best_bleu = 0.
#
#     def on_epoch_end(self, epoch, logs=None):
#         metrics = self.evaluate(valid_data)  # 评测模型
#         if metrics['bleu'] > self.best_bleu:
#             self.best_bleu = metrics['bleu']
#             model.save_weights('./best_model.weights')  # 保存模型
#         metrics['best_bleu'] = self.best_bleu
#         print('valid_data:', metrics)
#
#     def evaluate(self, data, topk=1):
#         total = 0
#         rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0
#         for title, content in tqdm(data):
#             total += 1
#             title = ' '.join(title).lower()
#             pred_title = ' '.join(autotitle.generate(content,
#                                                      topk=topk)).lower()
#             if pred_title.strip():
#                 scores = self.rouge.get_scores(hyps=pred_title, refs=title)
#                 rouge_1 += scores[0]['rouge-1']['f']
#                 rouge_2 += scores[0]['rouge-2']['f']
#                 rouge_l += scores[0]['rouge-l']['f']
#                 bleu += sentence_bleu(
#                     references=[title.split(' ')],
#                     hypothesis=pred_title.split(' '),
#                     smoothing_function=self.smooth
#                 )
#         rouge_1 /= total
#         rouge_2 /= total
#         rouge_l /= total
#         bleu /= total
#         return {
#             'rouge-1': rouge_1,
#             'rouge-2': rouge_2,
#             'rouge-l': rouge_l,
#             'bleu': bleu,
#         }

def just_show():
    """Print the generation for a fixed demo sentence, then a blank line."""
    demo_sentences = (
        "历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择,是时代对于青年们的深切呼唤",
    )
    for sentence in demo_sentences:
        generated = autotitle.generate(sentence)
        print(u'生成标题:', generated)
    print()


class Evaluator(keras.callbacks.Callback):
    """Checkpoint on the lowest training loss and print a demo generation.

    After each epoch the model weights are saved whenever the reported
    ``loss`` is less than or equal to the best value seen so far. A sample
    generation is then printed via ``just_show``.
    """
    def __init__(self):
        # Initialise the base Callback so its own attributes are set up.
        super().__init__()
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        """Save the best weights (if any loss was reported) and show a demo.

        ``logs`` may be ``None`` or lack ``'loss'``; checkpointing is then
        skipped instead of raising.
        """
        loss = (logs or {}).get('loss')
        # Keep the best weights
        if loss is not None and loss <= self.lowest:
            self.lowest = loss
            weights_path = './output_t5/best_model_t5_0724.weights'
            # Make sure the target directory exists before saving, so the
            # epoch's work is not lost to a missing folder.
            os.makedirs(os.path.dirname(weights_path), exist_ok=True)
            model.save_weights(weights_path)
        # Show a demo generation
        just_show()


if __name__ == '__main__':
    # Script mode: train, checkpointing through the Evaluator callback.
    evaluator = Evaluator()
    train_generator = data_generator(lines, batch_size)

    model.fit(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        epochs=epochs,
        callbacks=[evaluator]
    )

else:
    # Import mode: load the weights written by Evaluator during training.
    # The former './best_model.weights' path is never written by this
    # script, so it is aligned with the Evaluator's save path.
    model.load_weights('./output_t5/best_model_t5_0724.weights')