普通版降重 (standard-version paraphrasing / text-similarity reduction)

# -*- coding: utf-8 -*-
"""
@Time : 2022/8/5 13:41
@Author :
@FileName:
@Software:
@Describe:
"""
# Fine-tune the multilingual T5 (mT5) for a Seq2Seq task
# Introduction: https://kexue.fm/archives/7867
# Details: https://github.com/bojone/t5_in_bert4keras
# Dataset: the CSL dataset from https://github.com/CLUEbenchmark/CLGE
# Adds the evaluation metrics BLEU, ROUGE-1, ROUGE-2 and ROUGE-L
import os
# os.environ["TF_KERAS"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
import numpy as np
from tqdm import tqdm
from bert4keras.backend import keras, K
from bert4keras.layers import Loss
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import SpTokenizer
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, open
from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder
from keras.models import Model
# from rouge import Rouge # pip install rouge
# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import tensorflow as tf
# Let TensorFlow allocate GPU memory on demand instead of grabbing it all up front
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
# Basic parameters
max_c_len = 128
max_t_len = 128
batch_size = 28
epochs = 10000
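# Note: epochs is deliberately huge; in practice training runs until stopped manually,
# while the Evaluator callback below keeps checkpointing the lowest-loss weights.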
# Model paths
config_path = 'mt5/mt5_base_dropout_0_3_config.json'
checkpoint_path = 'mt5/mt5_base/model.ckpt-1000000'
spm_path = 'mt5/mt5_base/sentencepiece_cn.model'
keep_tokens_path = 'mt5/mt5_base/sentencepiece_cn_keep_tokens.json'
file = "data/train_yy_zong_sim_99.txt"
# Read the training corpus; try UTF-8 first and fall back to GBK
try:
    with open(file, 'r', encoding="utf-8") as f:
        lines = [x.strip() for x in f if x.strip() != '']
except UnicodeDecodeError:
    with open(file, 'r', encoding="gbk") as f:
        lines = [x.strip() for x in f if x.strip() != '']
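# Expected line format (an assumption based on how data_generator below consumes it):
# three tab-separated fields per line, where field 1 is the source sentence and
# field 3 is the rewritten target; the middle field is ignored.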
# Load the SentencePiece tokenizer
tokenizer = SpTokenizer(spm_path, token_start=None, token_end='</s>')
keep_tokens = json.load(open(keep_tokens_path))
class data_generator(DataGenerator):
    """Data generator for (source, rewritten) sentence pairs."""
    def __iter__(self, random=False):
        batch_c_token_ids, batch_t_token_ids = [], []
        for is_end, txt in self.sample(random):
            text = txt.split('\t')
            if len(text) == 3:
                content = text[0]
                content_g = text[2]
                c_token_ids, _ = tokenizer.encode(content, maxlen=max_c_len)
                t_token_ids, _ = tokenizer.encode(content_g, maxlen=max_t_len)
                # token_ids, segment_ids = tokenizer.encode(
                #     content, content_g, maxlen=max_c_len
                # )
                batch_c_token_ids.append(c_token_ids)
                # Prepend 0: mT5 uses id 0 as the decoder start token (shared with <pad>)
                batch_t_token_ids.append([0] + t_token_ids)
            if len(batch_c_token_ids) == self.batch_size or is_end:
                batch_c_token_ids = sequence_padding(batch_c_token_ids)
                batch_t_token_ids = sequence_padding(batch_t_token_ids)
                yield [batch_c_token_ids, batch_t_token_ids], None
                batch_c_token_ids, batch_t_token_ids = [], []
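# Each yielded batch is [encoder_token_ids, decoder_token_ids] with targets=None:
# the decoder ids double as the labels, and the CrossEntropy layer below computes
# the loss from them inside the graph.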
class CrossEntropy(Loss):
    """Cross-entropy loss, with masked (padding) positions excluded."""
    def compute_loss(self, inputs, mask=None):
        y_true, y_pred = inputs
        y_true = y_true[:, 1:]  # target token_ids
        y_mask = K.cast(mask[1], K.floatx())[:, 1:]  # the decoder's own padding mask
        y_pred = y_pred[:, :-1]  # predictions, shifted one step to align with the targets
        loss = K.sparse_categorical_crossentropy(y_true, y_pred)
        loss = K.sum(loss * y_mask) / K.sum(y_mask)
        return loss
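# Alignment example: for a decoder input [0, t1, t2, </s>], y_true becomes
# [t1, t2, </s>] and y_pred drops its last step, so the prediction at position i
# is scored against token i+1 (standard teacher forcing).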
t5 = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    keep_tokens=keep_tokens,
    model='mt5.1.1',
    return_keras_model=False,
    name='T5',
)
encoder = t5.encoder
decoder = t5.decoder
model = t5.model
model.summary()
output = CrossEntropy(1)([model.inputs[1], model.outputs[0]])
model = Model(model.inputs, output)
# The CrossEntropy layer adds the loss inside the graph, so compile() only needs an optimizer
model.compile(optimizer=Adam(2e-4))
# path_model = "output_t5/best_model_t5_juben.weights"
# model.load_weights(path_model)
class AutoTitle(AutoRegressiveDecoder):
    """Seq2seq decoder."""
    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        c_encoded = inputs[0]
        return self.last_token(decoder).predict([c_encoded, output_ids])

    def generate(self, text, topk=1):
        c_token_ids, _ = tokenizer.encode(text, maxlen=max_c_len)
        c_encoded = encoder.predict(np.array([c_token_ids]))[0]
        output_ids = self.beam_search([c_encoded], topk=topk)  # beam search decoding
        return tokenizer.decode([int(i) for i in output_ids])

# Note: T5 has a somewhat confusing convention: its <bos> id is 0, i.e. <bos> and <pad> share id 0
autotitle = AutoTitle(
    start_id=0, end_id=tokenizer._token_end_id, maxlen=max_t_len
)
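# Usage sketch (the input string is only an illustrative placeholder):
#     rewritten = autotitle.generate("待改写的原句", topk=3)  # larger topk widens the beam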
# Optional callback (commented out): evaluate with ROUGE/BLEU on a validation set and
# keep the weights with the best BLEU. Requires `pip install rouge`, nltk, and a
# `valid_data` list of (title, content) pairs.
# class Evaluator(keras.callbacks.Callback):
#     """Evaluate and save the best model
#     """
#     def __init__(self):
#         self.rouge = Rouge()
#         self.smooth = SmoothingFunction().method1
#         self.best_bleu = 0.
#
#     def on_epoch_end(self, epoch, logs=None):
#         metrics = self.evaluate(valid_data)  # evaluate the model
#         if metrics['bleu'] > self.best_bleu:
#             self.best_bleu = metrics['bleu']
#             model.save_weights('./best_model.weights')  # save the model
#         metrics['best_bleu'] = self.best_bleu
#         print('valid_data:', metrics)
#
#     def evaluate(self, data, topk=1):
#         total = 0
#         rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0
#         for title, content in tqdm(data):
#             total += 1
#             title = ' '.join(title).lower()
#             pred_title = ' '.join(autotitle.generate(content, topk=topk)).lower()
#             if pred_title.strip():
#                 scores = self.rouge.get_scores(hyps=pred_title, refs=title)
#                 rouge_1 += scores[0]['rouge-1']['f']
#                 rouge_2 += scores[0]['rouge-2']['f']
#                 rouge_l += scores[0]['rouge-l']['f']
#                 bleu += sentence_bleu(
#                     references=[title.split(' ')],
#                     hypothesis=pred_title.split(' '),
#                     smoothing_function=self.smooth
#                 )
#         rouge_1 /= total
#         rouge_2 /= total
#         rouge_l /= total
#         bleu /= total
#         return {
#             'rouge-1': rouge_1,
#             'rouge-2': rouge_2,
#             'rouge-l': rouge_l,
#             'bleu': bleu,
#         }
def just_show():
    """Print a sample rewrite so training progress can be eyeballed."""
    s2 = "历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择,是时代对于青年们的深切呼唤"
    for s in [s2]:
        print(u'生成标题:', autotitle.generate(s))
        print()
class Evaluator(keras.callbacks.Callback):
    """Save the best weights and show a sample rewrite at the end of each epoch."""
    def __init__(self):
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        # Save the best model (lowest training loss so far)
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save_weights('./output_t5/best_model_t5_zong_sim_99.weights')
        # Demo the current model
        just_show()
if __name__ == '__main__':
    evaluator = Evaluator()
    train_generator = data_generator(lines, batch_size)
    model.fit(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        epochs=epochs,
        callbacks=[evaluator]
    )
else:
    # When imported as a module, load previously trained weights instead of training
    model.load_weights('./best_model.weights')
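# Inference sketch (assumes the weights file saved by the Evaluator above exists;
# the sample sentence is only a placeholder):
#     model.load_weights('./output_t5/best_model_t5_zong_sim_99.weights')
#     print(autotitle.generate("需要降重的原始句子", topk=3))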