# -*- coding: utf-8 -*-

"""
@Time    :  2023/2/3 17:27
@Author  : 
@FileName: 
@Software: 
@Describe: ROUGE-1/2/L and BLEU scoring helper for generated text, with a small demo.
"""

import os
# os.environ["TF_KERAS"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from bert4keras.backend import keras, set_gelu
import numpy as np
from rouge import Rouge  # pip install rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction


class Evaluator(keras.callbacks.Callback):
    """Compute ROUGE-1/2/L and BLEU scores for generated text.

    Derived from a Keras training callback (evaluate and save the best
    model). In its current form it only exposes ``evaluate_t`` for scoring a
    single prediction against a single reference.
    """

    def __init__(self):
        # Let keras.callbacks.Callback initialise its own state first.
        super().__init__()
        self.rouge = Rouge()
        # Smoothing avoids a zero BLEU score when a higher-order n-gram
        # has no match, which is common for short texts.
        self.smooth = SmoothingFunction().method1
        # Best BLEU seen so far. Nothing in this class updates it; it is
        # kept for backward compatibility.
        self.best_bleu = 0.

    def evaluate_t(self, data_1, data_2, topk=1):
        """Score a hypothesis against a reference with ROUGE and BLEU.

        Both inputs must already be tokenised as space-separated tokens.
        For Chinese, ``' '.join(text)`` gives one token per character.

        Args:
            data_1: reference text (space-separated tokens).
            data_2: hypothesis / generated text (space-separated tokens).
            topk: unused; retained only for signature compatibility.

        Returns:
            ``[rouge_1_f, rouge_2_f, rouge_l_f, bleu]``. All four are 0.0
            when either text is blank, because the ``rouge`` package raises
            ``ValueError`` on an empty hypothesis or reference.
        """
        if not data_1.strip() or not data_2.strip():
            return [0., 0., 0., 0.]

        # Use the same roles for both metrics: data_1 is the reference and
        # data_2 the hypothesis. ROUGE-L's F-score is not symmetric, so
        # swapping the roles would change the result.
        scores = self.rouge.get_scores(hyps=[data_2], refs=[data_1])[0]
        bleu = sentence_bleu(
            references=[data_1.split(' ')],
            hypothesis=data_2.split(' '),
            smoothing_function=self.smooth,
        )
        return [
            scores['rouge-1']['f'],
            scores['rouge-2']['f'],
            scores['rouge-l']['f'],
            bleu,
        ]

def brevity_penalty(ref_len, hyp_len):
    """Return the BLEU brevity penalty for a hypothesis of ``hyp_len`` tokens.

    Uses the standard BLEU definition, the same one as
    ``nltk.translate.bleu_score.brevity_penalty``:

    * ``1.0`` when the hypothesis is longer than the reference,
    * ``0.0`` when the hypothesis is empty,
    * ``exp(1 - ref_len / hyp_len)`` otherwise.

    Args:
        ref_len: number of tokens in the reference.
        hyp_len: number of tokens in the hypothesis.
    """
    if hyp_len > ref_len:
        return 1.0
    if hyp_len == 0:
        # An empty hypothesis must give BLEU = 0 rather than divide by zero.
        return 0.0
    return float(np.exp(1 - ref_len / hyp_len))


if __name__ == "__main__":
    # Demo: score a short prediction against a longer reference. Chinese
    # has no word delimiters, so the characters are joined with spaces to
    # make each character one token.
    reference = "上海中心大厦"
    prediction = "上海"

    eval_class = Evaluator()
    eval_list = eval_class.evaluate_t(' '.join(reference), ' '.join(prediction))
    print(eval_list)

    # With one token per character, the raw string lengths equal the token
    # counts that sentence_bleu sees.
    length_ratio = len(prediction) / len(reference)
    print('length ratio (prediction / reference):', length_ratio)
    print('brevity penalty:', brevity_penalty(len(reference), len(prediction)))