"""
@Time    :  2022/8/15 15:20
@Author  :
@FileName:
@Software:
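@Describe: Scores candidate model outputs from an Excel sheet against a reference column with ROUGE, BLEU and difflib similarity, and writes the per-model averages to a CSV file.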
"""
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from bert4keras.backend import keras, K
from bert4keras.layers import Loss
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import SpTokenizer
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, open
from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder
from keras.models import Model
from rouge import Rouge  # pip install rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import difflib


class Evaluator(keras.callbacks.Callback):
    """Compute ROUGE-1/2/L, BLEU and a string-similarity score for texts.

    The class derives from the Keras callback base class; its epoch-end
    hook is currently disabled (kept below as a comment), so only the
    ``evaluate`` / ``evaluate_t`` helpers are active.
    """

    def __init__(self):
        # Let the Keras base class set up its own state first.
        super().__init__()
        self.rouge = Rouge()
        self.smooth = SmoothingFunction().method1
        self.best_bleu = 0.

    # Disabled epoch-end hook: evaluates on ``valid_data`` and saves the
    # weights whenever BLEU improves. ``valid_data`` and ``model`` are not
    # defined in the visible part of this file.
    # def on_epoch_end(self, epoch, logs=None):
    #     metrics = self.evaluate(valid_data)  # evaluate the model
    #     if metrics['bleu'] > self.best_bleu:
    #         self.best_bleu = metrics['bleu']
    #         model.save_weights('./best_model.weights')  # save the model
    #     metrics['best_bleu'] = self.best_bleu
    #     print('valid_data:', metrics)

    def evaluate(self, data, topk=1):
        """Return average ROUGE/BLEU of generated titles over ``data``.

        Relies on a module-level ``autotitle`` object exposing
        ``generate(content, topk=...)``; it is not defined in the visible
        part of this file and must exist when this method is called.

        Args:
            data: iterable of ``(title, content)`` pairs. The elements of
                each title (characters, for a string) are joined with
                spaces and lower-cased before scoring.
            topk: forwarded to ``autotitle.generate``.

        Returns:
            dict with keys ``'rouge-1'``, ``'rouge-2'``, ``'rouge-l'`` and
            ``'bleu'``. Samples whose prediction is blank add nothing to
            the sums but still count in the denominator. Empty ``data``
            yields all zeros.
        """
        total = 0
        rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0
        for title, content in tqdm(data):
            total += 1
            title = ' '.join(title).lower()
            pred_title = ' '.join(autotitle.generate(content,
                                                     topk=topk)).lower()
            # Blank predictions are skipped rather than scored.
            if pred_title.strip():
                scores = self.rouge.get_scores(hyps=pred_title, refs=title)
                rouge_1 += scores[0]['rouge-1']['f']
                rouge_2 += scores[0]['rouge-2']['f']
                rouge_l += scores[0]['rouge-l']['f']
                bleu += sentence_bleu(
                    references=[title.split(' ')],
                    hypothesis=pred_title.split(' '),
                    smoothing_function=self.smooth
                )
        if total == 0:
            # Nothing was evaluated; avoid dividing by zero.
            return {
                'rouge-1': 0.0,
                'rouge-2': 0.0,
                'rouge-l': 0.0,
                'bleu': 0.0,
            }
        return {
            'rouge-1': rouge_1 / total,
            'rouge-2': rouge_2 / total,
            'rouge-l': rouge_l / total,
            'bleu': bleu / total,
        }

    def evaluate_t(self, data_1, data_2, topk=1):
        """Score a single candidate text against a single reference text.

        Args:
            data_1: candidate (hypothesis) text. Its elements (characters,
                for a string) are joined with spaces before scoring.
            data_2: reference text, tokenised the same way.
            topk: unused; kept so existing callers keep working.

        Returns:
            list ``[rouge_1, rouge_2, rouge_l, bleu, str_sim]`` where the
            ROUGE entries are F-scores and ``str_sim`` is
            ``difflib.SequenceMatcher.quick_ratio()`` on the raw inputs
            (documented by difflib as an upper bound on ``ratio()``).
            If either text is blank after tokenisation, the ROUGE and
            BLEU entries are 0.0.
        """
        data_1_eval = ' '.join(data_1)
        data_2_eval = ' '.join(data_2)

        rouge_1, rouge_2, rouge_l, bleu = 0.0, 0.0, 0.0, 0.0
        # Same blank-text guard as ``evaluate``: only score real content.
        if data_1_eval.strip() and data_2_eval.strip():
            scores = self.rouge.get_scores(hyps=[data_1_eval],
                                           refs=[data_2_eval])
            rouge_1 = scores[0]['rouge-1']['f']
            rouge_2 = scores[0]['rouge-2']['f']
            rouge_l = scores[0]['rouge-l']['f']
            # Same roles as for ROUGE above: data_1 is the hypothesis and
            # data_2 the reference (BLEU is not symmetric in the two).
            bleu = sentence_bleu(
                references=[data_2_eval.split(' ')],
                hypothesis=data_1_eval.split(' '),
                smoothing_function=self.smooth
            )
        str_sim = difflib.SequenceMatcher(None, data_1, data_2).quick_ratio()
        return [rouge_1, rouge_2, rouge_l, bleu, str_sim]


# Offline comparison: score every candidate column of the spreadsheet
# against column 0 and write the per-model metric averages to a CSV file.
eval_class = Evaluator()

# Example call: eval_class.evaluate_t("星 辰 的 话", "星 辰 的 话 :")
path = "data/700条效果对比.xlsx"
path_out = "data/700条效果对比测评结果_14.csv"

# Row labels of the output table, one per candidate column, in the order
# the candidates are read from each spreadsheet row.
model_names = ["simbert_5day",
               "simbert_simdata4day",
               "simbert_simdata5day",
               "simbert_random20_5day",
               "simbert_simdata4day_yinhao",
               "simbert_simdata4day_yinhao_dropout",
               "simsim模型",
               "dropout_sim_03模型",
               "dropout_sim_04模型",
               "t5",
               "t5_dropout",
               "小说模型",
               "yy"]
# Output columns, in the order evaluate_t returns its scores.
metric_names = ["rouge_1", "rouge_2", "rouge_l", "bleu", "str_sim"]

data = pd.read_excel(path).values.tolist()

# One running sum per candidate column, for every metric.
list_class = [0] * len(model_names)
data_new = {metric: list_class.copy() for metric in metric_names}
total = len(data)

print(len(data))
for row_no, row in enumerate(data):
    # Columns 1..12 plus the last column hold the candidates; column 0 is
    # passed to evaluate_t as the text they are compared against.
    dan_list = [row[k] for k in range(1, 13)] + [row[-1]]
    for j, candidate in enumerate(dan_list):
        try:
            eval_list = eval_class.evaluate_t(candidate, row[0])
        except (TypeError, ValueError) as err:
            # Best effort: a cell that cannot be scored (e.g. a non-text
            # value) is reported and adds nothing to the sums, but the row
            # still counts in ``total``.
            print("skip row %d, column %d: %r" % (row_no, j, err))
            continue
        for metric, score in zip(metric_names, eval_list):
            data_new[metric][j] += score


def fune(x):
    """Turn an accumulated sum into the mean over all spreadsheet rows."""
    # An empty sheet has nothing to average; report 0.0 instead of failing.
    return x / total if total else 0.0


data = {metric: [fune(value) for value in sums]
        for metric, sums in data_new.items()}

pd.DataFrame(data, index=model_names).to_csv(path_out)