You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
147 lines
4.8 KiB
147 lines
4.8 KiB
"""
|
|
@Time : 2022/8/15 15:20
|
|
@Author :
|
|
@FileName:
|
|
@Software:
|
|
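@Describe: Scores the candidate text columns of an Excel sheet against column 0 with ROUGE-1/2/L, BLEU and a difflib similarity ratio, then writes the per-column averages to a CSV file.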
|
|
"""
|
|
import json
|
|
import numpy as np
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
from bert4keras.backend import keras, K
|
|
from bert4keras.layers import Loss
|
|
from bert4keras.models import build_transformer_model
|
|
from bert4keras.tokenizers import SpTokenizer
|
|
from bert4keras.optimizers import Adam
|
|
from bert4keras.snippets import sequence_padding, open
|
|
from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder
|
|
from keras.models import Model
|
|
from rouge import Rouge # pip install rouge
|
|
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
|
import difflib
|
|
|
|
|
|
class Evaluator(keras.callbacks.Callback):
    """Evaluation (and, when enabled, checkpointing) helper.

    Holds a ``Rouge`` scorer and an NLTK BLEU smoothing function and exposes
    two scoring entry points:

    * ``evaluate`` - averages ROUGE-1/2/L and BLEU over ``(title, content)``
      pairs, generating the prediction with the module-level ``autotitle``.
    * ``evaluate_t`` - scores a single pair of already available strings and
      additionally reports a ``difflib`` similarity ratio.
    """

    def __init__(self):
        # Initialise the Keras callback base class so its own attributes
        # exist before any framework code touches this object.
        super(Evaluator, self).__init__()
        self.rouge = Rouge()
        self.smooth = SmoothingFunction().method1
        self.best_bleu = 0.

    # Per-epoch validation and checkpointing, currently disabled:
    # def on_epoch_end(self, epoch, logs=None):
    #     metrics = self.evaluate(valid_data)  # score the model
    #     if metrics['bleu'] > self.best_bleu:
    #         self.best_bleu = metrics['bleu']
    #         model.save_weights('./best_model.weights')  # save the model
    #     metrics['best_bleu'] = self.best_bleu
    #     print('valid_data:', metrics)

    def evaluate(self, data, topk=1):
        """Average ROUGE-1/2/L F-scores and BLEU over a dataset.

        Args:
            data: iterable of ``(title, content)`` pairs. ``title`` is the
                reference; the prediction is produced by
                ``autotitle.generate(content, topk=topk)`` (``autotitle`` is
                a module-level name that must be defined by the caller's
                module).
            topk: forwarded to ``autotitle.generate``.

        Returns:
            dict with keys ``'rouge-1'``, ``'rouge-2'``, ``'rouge-l'`` and
            ``'bleu'``. Empty predictions contribute zero to every average.
            An empty ``data`` yields all zeros instead of dividing by zero.
        """
        total = 0
        rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0
        for title, content in tqdm(data):
            total += 1
            # Both sides are scored as space-separated, lower-cased tokens.
            title = ' '.join(title).lower()
            pred_title = ' '.join(
                autotitle.generate(content, topk=topk)
            ).lower()
            # Skip scoring for empty predictions; they still count in total.
            if pred_title.strip():
                scores = self.rouge.get_scores(hyps=pred_title, refs=title)
                rouge_1 += scores[0]['rouge-1']['f']
                rouge_2 += scores[0]['rouge-2']['f']
                rouge_l += scores[0]['rouge-l']['f']
                bleu += sentence_bleu(
                    references=[title.split(' ')],
                    hypothesis=pred_title.split(' '),
                    smoothing_function=self.smooth
                )
        if total:
            rouge_1 /= total
            rouge_2 /= total
            rouge_l /= total
            bleu /= total
        else:
            # Nothing was scored: report zeros rather than raising
            # ZeroDivisionError.
            rouge_1, rouge_2, rouge_l, bleu = 0.0, 0.0, 0.0, 0.0
        return {
            'rouge-1': rouge_1,
            'rouge-2': rouge_2,
            'rouge-l': rouge_l,
            'bleu': bleu,
        }

    def evaluate_t(self, data_1, data_2, topk=1):
        """Score one hypothesis string against one reference string.

        Args:
            data_1: hypothesis text (it is what the ROUGE call receives as
                ``hyps``).
            data_2: reference text (it is what the ROUGE call receives as
                ``refs``).
            topk: unused; kept so existing callers keep working.

        Returns:
            ``[rouge_1, rouge_2, rouge_l, bleu, str_sim]`` where the ROUGE
            values are F-scores and ``str_sim`` is
            ``difflib.SequenceMatcher(...).quick_ratio()`` computed on the
            raw inputs.
        """
        # Joining a string with ' ' puts a space between its characters, so
        # both metrics below work on space-separated units.
        hyp_text = ' '.join(data_1)
        ref_text = ' '.join(data_2)
        str_sim = difflib.SequenceMatcher(None, data_1, data_2).quick_ratio()

        # Same empty-text guard as `evaluate`: blank input scores zero on the
        # n-gram metrics instead of being handed to the scorers.
        # NOTE(review): the rouge package is believed to raise ValueError on
        # an empty hypothesis/reference - confirm.
        if not hyp_text.strip() or not ref_text.strip():
            return [0.0, 0.0, 0.0, 0.0, str_sim]

        scores = self.rouge.get_scores(hyps=[hyp_text], refs=[ref_text])
        rouge_1 = scores[0]['rouge-1']['f']
        rouge_2 = scores[0]['rouge-2']['f']
        rouge_l = scores[0]['rouge-l']['f']
        # BLEU is not symmetric, so the roles must match the ROUGE call above
        # (and `evaluate`): data_2 is the reference, data_1 the hypothesis.
        bleu = sentence_bleu(
            references=[ref_text.split(' ')],
            hypothesis=hyp_text.split(' '),
            smoothing_function=self.smooth
        )
        return [rouge_1, rouge_2, rouge_l, bleu, str_sim]
|
|
|
|
|
|
# Scorer shared by the rest of the script.
eval_class = Evaluator()

# Manual smoke test: call eval_class.evaluate_t on a short text and on the
# same text with a trailing colon, and print the returned list.

# Input spreadsheet and output CSV locations.
path = "data/700条效果对比.xlsx"
path_out = "data/700条效果对比测评结果_14.csv"

# Every spreadsheet row as a plain Python list.
data = pd.read_excel(path).values.tolist()

# One running sum per scored column (13 of them) for each metric; every
# metric gets its own copy of the zero-filled template.
list_class = [0] * 13
data_new = {
    metric: list(list_class)
    for metric in ("rouge_1", "rouge_2", "rouge_l", "bleu", "str_sim")
}

# Number of rows, used later as the averaging denominator.
total = len(data)

print(total)
|
|
# Keys of `data_new`, in the order `evaluate_t` returns its scores.
metric_keys = ("rouge_1", "rouge_2", "rouge_l", "bleu", "str_sim")

for row_no, row in enumerate(data):
    # Texts to score: columns 1..12 plus the last column. Column 0 is the
    # text each of them is scored against.
    dan_list = [row[k] for k in range(1, 13)] + [row[-1]]
    for j, candidate in enumerate(dan_list):
        # The scoring call is the step that can fail, so it is the step
        # that is guarded. A failed cell adds nothing to the sums but is
        # reported instead of being silently swallowed.
        # NOTE(review): blank spreadsheet cells presumably arrive as
        # non-string values (NaN) and trigger the TypeError path - confirm.
        try:
            eval_list = eval_class.evaluate_t(candidate, row[0])
        except (TypeError, ValueError) as err:
            print("skipped row %d, column %d: %r" % (row_no, j, err))
            continue
        for key, score in zip(metric_keys, eval_list):
            data_new[key][j] += score
|
|
|
|
def fune(x):
    """Return ``x`` divided by the module-level row count ``total``."""
    return x / total


# Rebind `data` from the raw rows to the per-metric averages: every running
# sum in `data_new` is divided by the number of rows.
data = {
    metric: [fune(column_sum) for column_sum in sums]
    for metric, sums in data_new.items()
}
|
|
|
|
# Row labels for the output table, one per scored column, in the same order
# as the running sums were accumulated.
row_labels = [
    "simbert_5day",
    "simbert_simdata4day",
    "simbert_simdata5day",
    "simbert_random20_5day",
    "simbert_simdata4day_yinhao",
    "simbert_simdata4day_yinhao_dropout",
    "simsim模型",
    "dropout_sim_03模型",
    "dropout_sim_04模型",
    "t5",
    "t5_dropout",
    "小说模型",
    "yy",
]

# One column per metric, one row per label; written to `path_out` as CSV.
result_frame = pd.DataFrame(data, index=row_labels)
result_frame.to_csv(path_out)
|