You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
70 lines
1.8 KiB
70 lines
1.8 KiB
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
@Time : 2023/2/3 17:27
|
|
@Author :
|
|
@FileName:
|
|
@Software:
|
|
@Describe:
|
|
"""
|
|
|
|
import os
|
|
# os.environ["TF_KERAS"] = "1"
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
from bert4keras.backend import keras, set_gelu
|
|
import numpy as np
|
|
from rouge import Rouge # pip install rouge
|
|
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
|
|
|
|
|
class Evaluator(keras.callbacks.Callback):
    """Score a generated sentence against a reference with ROUGE and BLEU.

    Subclasses the Keras ``Callback`` so it can be wired into training later;
    at present it does not hook any training events and only exposes
    :meth:`evaluate_t` for scoring a single sentence pair.
    """

    def __init__(self):
        # Run the base-class constructor so Keras can set up its own callback
        # state (e.g. ``model``) before this instance is used as a callback.
        super(Evaluator, self).__init__()
        self.rouge = Rouge()
        # method1 smoothing adds a small epsilon to n-gram precisions with
        # zero matches, so short sentences without any higher-order overlap
        # do not collapse to a BLEU of exactly 0.
        self.smooth = SmoothingFunction().method1
        # Best BLEU seen so far; reserved for checkpointing, which this class
        # does not implement yet.
        self.best_bleu = 0.

    def evaluate_t(self, data_1, data_2, topk=1):
        """Return ROUGE-1/2/L F-scores and smoothed BLEU for one sentence pair.

        Both texts must already be tokenized with spaces between tokens
        (for Chinese, e.g. ``' '.join(sentence)`` to use characters as tokens).

        Args:
            data_1: Reference (ground-truth) text.
            data_2: Hypothesis (generated) text, scored against ``data_1``.
            topk: Unused; kept so existing callers keep working.

        Returns:
            ``[rouge_1, rouge_2, rouge_l, bleu]``. All four are ``0.0`` when
            the hypothesis is blank, instead of letting ROUGE fail on empty
            input.
        """
        # split() (not split(' ')) so repeated spaces never produce empty
        # tokens; this also matches how ROUGE tokenizes internally.
        ref_tokens = data_1.split()
        hyp_tokens = data_2.split()
        if not hyp_tokens:
            return [0.0, 0.0, 0.0, 0.0]

        # Both metrics must agree on roles: data_2 is the hypothesis and
        # data_1 the reference. Neither metric is symmetric in its arguments.
        scores = self.rouge.get_scores(hyps=[data_2], refs=[data_1])[0]
        bleu = sentence_bleu(
            references=[ref_tokens],
            hypothesis=hyp_tokens,
            smoothing_function=self.smooth,
        )
        return [
            scores['rouge-1']['f'],
            scores['rouge-2']['f'],
            scores['rouge-l']['f'],
            bleu,
        ]
|
|
|
|
def brevity_penalty(hyp_len, ref_len):
    """Return the BLEU brevity penalty for a hypothesis of ``hyp_len`` tokens.

    Uses the standard definition:
    - 1.0 when the hypothesis is longer than the reference of ``ref_len`` tokens;
    - 0.0 for an empty hypothesis, which avoids dividing by zero;
    - otherwise ``exp(1 - ref_len / hyp_len)``.
    """
    if hyp_len > ref_len:
        return 1.0
    if hyp_len == 0:
        return 0.0
    # float() keeps true division on Python 2 and returns a plain float.
    return float(np.exp(1 - float(ref_len) / hyp_len))


if __name__ == "__main__":
    eval_class = Evaluator()

    data_1 = "上海中心大厦"
    data_2 = "上海"

    # Both metrics expect space-separated tokens; joining the characters with
    # spaces makes every Chinese character one token.
    eval_list = eval_class.evaluate_t(' '.join(data_1), ' '.join(data_2))
    print(eval_list)

    # Treating data_2 as the hypothesis and data_1 as the reference, show how
    # much BLEU is penalized because the hypothesis is shorter.
    print(brevity_penalty(len(data_2), len(data_1)))
|
|
|