# -*- coding: utf-8 -*-
"""
@Time    : 2022/9/19 14:43
@Author  :
@FileName:
@Software:
@Describe:
"""

# Training environment: tensorflow 1.14 + keras 2.3.1 + bert4keras 0.7.7
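
# SimBERT-style training: a RoBERTa-wwm-ext base is fine-tuned with a UniLM
# seq2seq objective (generate the paired text from the source text) plus an
# in-batch sentence-similarity objective on the [CLS] vectors (see TotalLoss).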

import os

# os.environ["TF_KERAS"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import json
import numpy as np
from collections import Counter
from bert4keras.backend import keras, K
from bert4keras.layers import Loss
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer, load_vocab
from bert4keras.optimizers import Adam, extend_with_weight_decay
from bert4keras.snippets import DataGenerator
from bert4keras.snippets import sequence_padding
from bert4keras.snippets import text_segmentate
from bert4keras.snippets import AutoRegressiveDecoder
# from bert4keras.snippets import uniout
import tensorflow as tf
from keras.backend import set_session

# Allocate GPU memory on demand instead of reserving it all up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config))  # this part differs from the reference script

# Basic settings
maxlen = 256
batch_size = 32
steps_per_epoch = 40000
epochs = 10000

# BERT configuration
config_path = './bert_config_dropout_0_3.json'
checkpoint_path = './chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
dict_path = './chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab_drop.txt'

file = "data/train_yy_sim.txt"
try:
    with open(file, 'r', encoding="utf-8") as f:
        lines = [x.strip() for x in f if x.strip() != '']
except UnicodeDecodeError:
    # Fall back to GBK if the file is not UTF-8 encoded.
    with open(file, 'r', encoding="gbk") as f:
        lines = [x.strip() for x in f if x.strip() != '']
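
# Expected line format (inferred from data_generator below): three
# tab-separated fields, where field 0 is the source text and field 2 the
# paired target text; lines with any other field count are skipped.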

# Load and simplify the vocabulary, then build the tokenizer
token_dict, keep_tokens = load_vocab(
    dict_path=dict_path,
    simplified=True,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
)
tokenizer = Tokenizer(token_dict, do_lower_case=True)
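
# Quick sanity check (hypothetical input; assumes these characters survive the
# simplified vocab):
#   token_ids, segment_ids = tokenizer.encode(u'你好', u'您好', maxlen=maxlen)
# token_ids covers [CLS] 你 好 [SEP] 您 好 [SEP]; segment_ids is 0 over the
# first segment and 1 over the second, which is what the UniLM mask keys on.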


class data_generator(DataGenerator):
    """Data generator."""
    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids = [], []
        for is_end, txt in self.sample(random):
            text = txt.split('\t')
            if len(text) == 3:
                content = text[0]
                content_g = text[2]
                # The similarity loss (see get_labels_of_similarity) assumes
                # rows 2k and 2k+1 of a batch form a positive pair, so each
                # sample is encoded in both directions, as in SimBERT.
                token_ids, segment_ids = tokenizer.encode(
                    content, content_g, maxlen=maxlen
                )
                batch_token_ids.append(token_ids)
                batch_segment_ids.append(segment_ids)
                token_ids, segment_ids = tokenizer.encode(
                    content_g, content, maxlen=maxlen
                )
                batch_token_ids.append(token_ids)
                batch_segment_ids.append(segment_ids)
                if len(batch_token_ids) >= self.batch_size or is_end:
                    batch_token_ids = sequence_padding(batch_token_ids)
                    batch_segment_ids = sequence_padding(batch_segment_ids)
                    yield [batch_token_ids, batch_segment_ids], None
                    batch_token_ids, batch_segment_ids = [], []


class TotalLoss(Loss):
    """The loss has two parts: the seq2seq cross-entropy and the
    similarity cross-entropy.
    """
    def compute_loss(self, inputs, mask=None):
        loss1 = self.compute_loss_of_seq2seq(inputs, mask)
        loss2 = self.compute_loss_of_similarity(inputs, mask)
        self.add_metric(loss1, name='seq2seq_loss')
        self.add_metric(loss2, name='similarity_loss')
        return loss1 + loss2

    def compute_loss_of_seq2seq(self, inputs, mask=None):
        y_true, y_mask, _, y_pred = inputs
        y_true = y_true[:, 1:]  # target token_ids
        y_mask = y_mask[:, 1:]  # segment_ids, which exactly mark the part to predict
        y_pred = y_pred[:, :-1]  # predicted sequence, shifted by one position
        loss = K.sparse_categorical_crossentropy(y_true, y_pred)
        loss = K.sum(loss * y_mask) / K.sum(y_mask)
        return loss

    def compute_loss_of_similarity(self, inputs, mask=None):
        _, _, y_pred, _ = inputs
        y_true = self.get_labels_of_similarity(y_pred)  # build the labels
        y_pred = K.l2_normalize(y_pred, axis=1)  # normalize the sentence vectors
        similarities = K.dot(y_pred, K.transpose(y_pred))  # similarity matrix
        similarities = similarities - K.eye(K.shape(y_pred)[0]) * 1e12  # mask the diagonal
        similarities = similarities * 30  # scale
        loss = K.categorical_crossentropy(
            y_true, similarities, from_logits=True
        )
        return loss

    def get_labels_of_similarity(self, y_pred):
        # Row i's positive example is its neighbour in the (2k, 2k+1) pair;
        # every other row in the batch serves as a negative.
        idxs = K.arange(0, K.shape(y_pred)[0])
        idxs_1 = idxs[None, :]
        idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
        labels = K.equal(idxs_1, idxs_2)
        labels = K.cast(labels, K.floatx())
        return labels
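
# Worked example of the similarity labels: for a batch of 4 rows arranged as
# (a, a'), (b, b'), idxs_2 = [1, 0, 3, 2], so the label matrix is
#   [[0, 1, 0, 0],
#    [1, 0, 0, 0],
#    [0, 0, 0, 1],
#    [0, 0, 1, 0]]
# i.e. each row's only positive is its paired row; the diagonal is removed by
# the -1e12 mask above, and all remaining rows act as in-batch negatives.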

# Build and load the model
bert = build_transformer_model(
    config_path,
    checkpoint_path,
    with_pool='linear',
    application='unilm',
    keep_tokens=keep_tokens,  # keep only the tokens in keep_tokens, simplifying the original vocab
    return_keras_model=False,
    ignore_invalid_weights=True
)
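
# With with_pool='linear' and application='unilm', bert.model has two outputs:
# outputs[0] is the pooled [CLS] vector (the sentence embedding used by the
# similarity loss) and outputs[1] is the per-token probability distribution
# used by the seq2seq loss and the decoder below.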
encoder = keras.models.Model(bert.model.inputs, bert.model.outputs[0])
seq2seq = keras.models.Model(bert.model.inputs, bert.model.outputs[1])

# TotalLoss([2, 3]) attaches the combined loss and passes inputs 2 and 3
# (the two BERT outputs) through unchanged as the layer's outputs.
outputs = TotalLoss([2, 3])(bert.model.inputs + bert.model.outputs)
model = keras.models.Model(bert.model.inputs, outputs)

AdamW = extend_with_weight_decay(Adam, 'AdamW')
optimizer = AdamW(learning_rate=2e-6, weight_decay_rate=0.01)
model.compile(optimizer=optimizer)
model.summary()
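
# Note: compile() takes no loss argument because TotalLoss already attached
# the loss to the graph via add_loss; accordingly the data generator yields
# None as the target.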


class AutoTitle(AutoRegressiveDecoder):
    """Seq2seq decoder."""
    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        token_ids, segment_ids = inputs
        token_ids = np.concatenate([token_ids, output_ids], 1)
        segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
        return self.last_token(model).predict([token_ids, segment_ids])

    def generate(self, text, n=1, topk=5):
        token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
        output_ids = self.random_sample([token_ids, segment_ids], n,
                                        topk)  # random sampling
        return [tokenizer.decode(ids) for ids in output_ids]

    def generate_(self, text, topk=1):
        # Reserve room for up to self.maxlen generated tokens.
        max_c_len = maxlen - self.maxlen
        token_ids, segment_ids = tokenizer.encode(text, maxlen=max_c_len)
        output_ids = self.beam_search([token_ids, segment_ids],
                                      topk=topk)  # beam search
        return tokenizer.decode(output_ids)


autotitle = AutoTitle(start_id=None, end_id=tokenizer._token_end_id, maxlen=120)
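
# Usage sketch (hypothetical inputs):
#   autotitle.generate(u'原始文本', n=2, topk=5)   # two sampled rewrites
#   autotitle.generate_(u'原始文本', topk=3)       # one beam-searched rewrite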


def just_show():
    # "Although somewhat puzzled, everyone only dared to keep smiles on their
    # faces while slowly coaxing some words out of Su Xi's mouth."
    s2 = u'尽管是有些疑惑,但大家也只敢是脸上带着笑意,慢慢地从苏溪的嘴里面套一些话出来。'
    for s in [s2]:
        print(u'Generated:', autotitle.generate(s))
    print()


class Evaluate(keras.callbacks.Callback):
    """Evaluate the model: save the latest weights each epoch and keep the
    weights with the lowest loss seen so far.
    """
    def __init__(self):
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        model.save_weights('./output_simbert_yy/latest_simbertmodel_dropout_datasim_yinhao.weights')
        # Save the best model
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save_weights('./output_simbert_yy/best_simbertmodel_dropout_datasim_yinhao.weights')
        # Show sample generations
        # just_show()


if __name__ == '__main__':

    train_generator = data_generator(lines, batch_size)
    evaluator = Evaluate()

    model.fit_generator(
        train_generator.forfit(),
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        callbacks=[evaluator]
    )

# else:
#
#     model.load_weights('./latest_model.weights')