You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
227 lines
7.4 KiB
227 lines
7.4 KiB
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
@Time : 2022/8/5 13:41
|
|
@Author :
|
|
@FileName:
|
|
@Software:
|
|
@Describe:
|
|
"""
|
|
|
|
# ! -*- coding: utf-8 -*-
|
|
# 微调多国语言版T5做Seq2Seq任务
|
|
# 介绍链接:https://kexue.fm/archives/7867
|
|
# 细节请看:https://github.com/bojone/t5_in_bert4keras
|
|
# 数据集:https://github.com/CLUEbenchmark/CLGE 中的CSL数据集
|
|
# 补充了评测指标bleu、rouge-1、rouge-2、rouge-l
|
|
import os
|
|
# os.environ["TF_KERAS"] = "1"
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
import json
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
from bert4keras.backend import keras, K
|
|
from bert4keras.layers import Loss
|
|
from bert4keras.models import build_transformer_model
|
|
from bert4keras.tokenizers import SpTokenizer
|
|
from bert4keras.optimizers import Adam
|
|
from bert4keras.snippets import sequence_padding, open
|
|
from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder
|
|
from keras.models import Model
|
|
# from rouge import Rouge # pip install rouge
|
|
# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
|
import tensorflow as tf
|
|
# Ask TensorFlow to grow GPU memory on demand for every visible GPU,
# rather than reserving the whole device up front. This must run before
# any GPU is initialised, which is why it sits right after the imports.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
|
|
|
|
|
|
# Basic hyper-parameters.
# Maximum token lengths passed to tokenizer.encode for the encoder input
# (source) and the decoder target, respectively.
max_c_len = max_t_len = 128
batch_size = 28
epochs = 10000

# Model file locations: mT5 config, checkpoint, SentencePiece model and the
# JSON list of token ids to keep.
_MT5_BASE_DIR = 'mt5/mt5_base'
config_path = 'mt5/mt5_base_dropout_0_3_config.json'
checkpoint_path = _MT5_BASE_DIR + '/model.ckpt-1000000'
spm_path = _MT5_BASE_DIR + '/sentencepiece_cn.model'
keep_tokens_path = _MT5_BASE_DIR + '/sentencepiece_cn_keep_tokens.json'
|
|
|
|
|
|
file = "data/train_yy_zong_sim_99.txt"


def _read_nonempty_lines(path, encoding):
    """Return the stripped, non-blank lines of ``path`` decoded with ``encoding``.

    Raises:
        UnicodeDecodeError: if the file is not valid in ``encoding``.
        OSError: if the file cannot be opened.
    """
    with open(path, 'r', encoding=encoding) as f:
        return [s for s in (x.strip() for x in f) if s]


# Read the training data as UTF-8 and fall back to GBK only when UTF-8
# decoding fails. Other errors (missing file, permissions, interrupts) must
# propagate instead of being silently retried with a different codec.
try:
    lines = _read_nonempty_lines(file, "utf-8")
except UnicodeDecodeError:
    lines = _read_nonempty_lines(file, "gbk")

# Load the SentencePiece tokenizer (no start token, '</s>' as end token) and
# the list of token ids to keep; the file handle is closed after reading.
tokenizer = SpTokenizer(spm_path, token_start=None, token_end='</s>')
with open(keep_tokens_path) as f:
    keep_tokens = json.load(f)
|
|
|
|
|
|
class data_generator(DataGenerator):
    """Data generator yielding padded (encoder input, decoder input) batches.

    Each sample is a line split on tabs; only lines with exactly three fields
    are used, with field 0 as the encoder (source) text and field 2 as the
    decoder (target) text. Lines with any other field count are skipped.
    Each yielded item is ``[source_ids, target_ids], None``.
    """

    def __iter__(self, random=False):
        batch_c_token_ids, batch_t_token_ids = [], []
        for is_end, txt in self.sample(random):
            fields = txt.split('\t')
            if len(fields) == 3:
                content, content_g = fields[0], fields[2]
                c_token_ids, _ = tokenizer.encode(content, maxlen=max_c_len)
                t_token_ids, _ = tokenizer.encode(content_g, maxlen=max_t_len)
                batch_c_token_ids.append(c_token_ids)
                # Prepend id 0, which T5 uses as its decoder start (<bos>) id.
                batch_t_token_ids.append([0] + t_token_ids)
            # Flush independently of whether the current line was valid, so a
            # malformed last line cannot swallow the final partial batch, but
            # never emit an empty batch (sequence_padding needs samples).
            if batch_c_token_ids and (
                len(batch_c_token_ids) == self.batch_size or is_end
            ):
                yield [
                    sequence_padding(batch_c_token_ids),
                    sequence_padding(batch_t_token_ids),
                ], None
                batch_c_token_ids, batch_t_token_ids = [], []
|
|
|
|
|
|
class CrossEntropy(Loss):
    """Cross-entropy between the decoder targets and its predictions.

    ``inputs`` is ``(target_ids, predictions)``. The loss is averaged over the
    positions where the mask attached to the predictions is non-zero.
    """

    def compute_loss(self, inputs, mask=None):
        targets, probs = inputs
        # Prediction step i is scored against target token i + 1, so drop the
        # leading start token from the targets and the last prediction step.
        next_tokens = targets[:, 1:]
        # mask[1] belongs to the second input (the decoder output); presumably
        # the padding mask propagated by the decoder -- confirm.
        weights = K.cast(mask[1], K.floatx())[:, 1:]
        step_probs = probs[:, :-1]
        per_token = K.sparse_categorical_crossentropy(next_tokens, step_probs)
        return K.sum(per_token * weights) / K.sum(weights)
|
|
|
|
|
|
# Build the mT5 model (model='mt5.1.1') from the pretrained checkpoint;
# keep_tokens presumably restricts the vocabulary to the kept ids -- confirm.
t5 = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    keep_tokens=keep_tokens,
    model='mt5.1.1',
    return_keras_model=False,
    name='T5',
)

encoder, decoder, model = t5.encoder, t5.decoder, t5.model
model.summary()

# Attach the loss as a layer fed with the decoder input ids (targets) and the
# decoder output. The argument 1 presumably selects which input the layer
# passes through as its output (bert4keras Loss) -- confirm.
output = CrossEntropy(1)([model.inputs[1], model.outputs[0]])

model = Model(model.inputs, output)
model.compile(optimizer=Adam(2e-4))

# To resume from previously saved weights, uncomment:
# path_model = "output_t5/best_model_t5_juben.weights"
# model.load_weights(path_model)
|
|
|
|
class AutoTitle(AutoRegressiveDecoder):
    """Seq2seq decoder that generates text from the T5 encoder/decoder pair."""

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        """Score the next token given the encoder output and the decoded prefix."""
        encoder_states = inputs[0]
        step_model = self.last_token(decoder)
        return step_model.predict([encoder_states, output_ids])

    def generate(self, text, topk=1):
        """Encode ``text`` and decode an output with beam search of width ``topk``."""
        source_ids, _segments = tokenizer.encode(text, maxlen=max_c_len)
        encoder_states = encoder.predict(np.array([source_ids]))[0]
        best_ids = self.beam_search([encoder_states], topk=topk)  # beam search
        return tokenizer.decode(list(map(int, best_ids)))
|
|
|
|
|
|
# Note: T5 has a puzzling convention -- its <bos> id is 0, the same id as
# <pad>, so decoding starts from id 0.
_autotitle_kwargs = dict(
    start_id=0,
    end_id=tokenizer._token_end_id,
    maxlen=max_t_len,
)
autotitle = AutoTitle(**_autotitle_kwargs)
|
|
|
|
|
|
# class Evaluator(keras.callbacks.Callback):
|
|
# """评估与保存
|
|
# """
|
|
#
|
|
# def __init__(self):
|
|
# self.rouge = Rouge()
|
|
# self.smooth = SmoothingFunction().method1
|
|
# self.best_bleu = 0.
|
|
#
|
|
# def on_epoch_end(self, epoch, logs=None):
|
|
# metrics = self.evaluate(valid_data) # 评测模型
|
|
# if metrics['bleu'] > self.best_bleu:
|
|
# self.best_bleu = metrics['bleu']
|
|
# model.save_weights('./best_model.weights') # 保存模型
|
|
# metrics['best_bleu'] = self.best_bleu
|
|
# print('valid_data:', metrics)
|
|
#
|
|
# def evaluate(self, data, topk=1):
|
|
# total = 0
|
|
# rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0
|
|
# for title, content in tqdm(data):
|
|
# total += 1
|
|
# title = ' '.join(title).lower()
|
|
# pred_title = ' '.join(autotitle.generate(content,
|
|
# topk=topk)).lower()
|
|
# if pred_title.strip():
|
|
# scores = self.rouge.get_scores(hyps=pred_title, refs=title)
|
|
# rouge_1 += scores[0]['rouge-1']['f']
|
|
# rouge_2 += scores[0]['rouge-2']['f']
|
|
# rouge_l += scores[0]['rouge-l']['f']
|
|
# bleu += sentence_bleu(
|
|
# references=[title.split(' ')],
|
|
# hypothesis=pred_title.split(' '),
|
|
# smoothing_function=self.smooth
|
|
# )
|
|
# rouge_1 /= total
|
|
# rouge_2 /= total
|
|
# rouge_l /= total
|
|
# bleu /= total
|
|
# return {
|
|
# 'rouge-1': rouge_1,
|
|
# 'rouge-2': rouge_2,
|
|
# 'rouge-l': rouge_l,
|
|
# 'bleu': bleu,
|
|
# }
|
|
|
|
def just_show(texts=None):
    """Print the model's generation for each sample text, then a blank line.

    Args:
        texts: iterable of source strings to generate from. When omitted, a
            single built-in Chinese sample sentence is used (the original
            behaviour).
    """
    if texts is None:
        texts = ["历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择,是时代对于青年们的深切呼唤"]
    for s in texts:
        print(u'生成标题:', autotitle.generate(s))
    print()
|
|
|
|
|
|
class Evaluator(keras.callbacks.Callback):
    """Keep the weights with the lowest training loss and print a demo each epoch.

    Args:
        weights_path: file the best weights are written to; its directory is
            created if missing.
    """

    def __init__(self, weights_path='./output_t5/best_model_t5_zong_sim_99.weights'):
        # Let the base Callback initialise its own attributes (model, etc.).
        super().__init__()
        self.lowest = float('inf')
        self.weights_path = weights_path

    def on_epoch_end(self, epoch, logs=None):
        # Keras may pass logs=None; treat a missing loss as "nothing to save".
        loss = (logs or {}).get('loss')
        # Save whenever this epoch's loss ties or beats the best so far
        # (a NaN loss compares False and is therefore never saved).
        if loss is not None and loss <= self.lowest:
            self.lowest = loss
            out_dir = os.path.dirname(self.weights_path)
            if out_dir:
                # save_weights does not create missing directories.
                os.makedirs(out_dir, exist_ok=True)
            model.save_weights(self.weights_path)
        # Show a sample generation so progress can be inspected by eye.
        just_show()
|
|
|
|
|
|
if __name__ == '__main__':
    # Train indefinitely on the loaded lines; the Evaluator callback keeps the
    # lowest-loss weights and prints a sample generation after each epoch.
    evaluator = Evaluator()
    train_generator = data_generator(lines, batch_size)

    model.fit(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        epochs=epochs,
        callbacks=[evaluator],
    )
else:
    # When imported (e.g. for inference), restore the checkpoint written
    # during training; this path must match the one Evaluator saves to.
    model.load_weights('./output_t5/best_model_t5_zong_sim_99.weights')
|
|
|
|
|