# -*- coding: utf-8 -*-
"""
@Time : 2023/1/16 14:59
@Author :
@FileName:
@Software:
@Describe:
"""

import os

# os.environ["TF_KERAS"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import glob
import json

import numpy as np
import pandas as pd
from numpy import random
from tqdm import tqdm

random.seed(1001)

from bert4keras.backend import keras, K
from bert4keras.layers import Loss
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import SpTokenizer
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, open
from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder
from keras.models import Model

# from rouge import Rouge  # pip install rouge
# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

import tensorflow as tf
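
# NOTE (environment assumption): the session handling below relies on TensorFlow 1.x
# APIs (tf.ConfigProto, tf.Session, tf.get_default_graph). Under TensorFlow 2.x the
# equivalent calls live in tf.compat.v1 and eager execution would have to be disabled.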

from keras.backend import set_session

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config))  # differs here: the Keras session is set explicitly

global graph
graph = tf.get_default_graph()
sess = tf.Session(graph=graph)
set_session(sess)
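
# `graph` and `sess` are kept at module level so that later inference calls
# (AutoTitle.predict / generate_random below) can re-activate exactly this graph
# and session via `graph.as_default()` + `K.set_session(sess)`. That is the usual
# workaround when a TF1/Keras model is called from another thread (e.g. a web
# worker); in a single-threaded script it is not strictly necessary.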

# global graph, model
# graph = tf.get_default_graph()
# sess = tf.Session(graph=graph)
# K.set_session(sess)


# Basic parameters

class GenerateModel(object):
    def __init__(self):
        self.epoch_acc_vel = 0
        self.config_path = 'mt5/mt5_base/mt5_base_config.json'
        self.checkpoint_path = 'mt5/mt5_base/model.ckpt-1000000'
        self.spm_path = 'mt5/mt5_base/sentencepiece_cn.model'
        self.keep_tokens_path = 'mt5/mt5_base/sentencepiece_cn_keep_tokens.json'
        self.maxlen = 256
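        # These paths are assumed to point to a bert4keras-style mT5-base setup:
        # the original mT5 checkpoint plus a Chinese-trimmed sentencepiece model
        # and keep_tokens mapping used to shrink the vocabulary.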

    def device_setup(self):
        tokenizer = SpTokenizer(self.spm_path, token_start=None, token_end='</s>')
        keep_tokens = json.load(open(self.keep_tokens_path))

        t5 = build_transformer_model(
            config_path=self.config_path,
            checkpoint_path=self.checkpoint_path,
            keep_tokens=keep_tokens,
            model='mt5.1.1',
            return_keras_model=False,
            name='T5',
        )

        # output = CrossEntropy(2)(model.inputs + model.outputs)
        # model = Model(model.inputs, output)
        encoder = t5.encoder
        decoder = t5.decoder
        model = t5.model
        model.summary()

        output = CrossEntropy(1)([model.inputs[1], model.outputs[0]])
        model = Model(model.inputs, output)

        path_model = "output_t5/best_model_t5_dropout_0_3.weights"
        model.load_weights(path_model)

        return encoder, decoder, model, tokenizer


class CrossEntropy(Loss):
    """Cross-entropy loss, with the input (prompt) part masked out."""

    def compute_loss(self, inputs, mask=None):
        y_true, y_pred = inputs
        y_true = y_true[:, 1:]  # target token_ids
        y_mask = K.cast(mask[1], K.floatx())[:, 1:]  # mask supplied by the decoder
        y_pred = y_pred[:, :-1]  # predicted sequence, shifted by one position
        loss = K.sparse_categorical_crossentropy(y_true, y_pred)
        loss = K.sum(loss * y_mask) / K.sum(y_mask)
        return loss
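
        # Alignment note: the decoder predicts token t[i+1] at position i, so the
        # targets drop the first token and the predictions drop the last one before
        # the two sequences are compared.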


class AutoTitle(AutoRegressiveDecoder):
    """seq2seq decoder."""

    def __init__(self, encoder, decoder, model, tokenizer, start_id, end_id, maxlen, minlen=1):
        super(AutoTitle, self).__init__(start_id, end_id, maxlen, minlen)
        self.encoder = encoder
        self.decoder = decoder
        self.model = model
        self.tokenizer = tokenizer
        self.start_id = start_id
        self.end_id = end_id
        self.minlen = minlen
        self.models = {}
        if start_id is None:
            self.first_output_ids = np.empty((1, 0), dtype=int)
        else:
            self.first_output_ids = np.array([[self.start_id]])

    # @AutoRegressiveDecoder.wraps(default_rtype='probas')
    # def predict(self, inputs, output_ids, states):
    #     token_ids, segment_ids = inputs
    #     token_ids = np.concatenate([token_ids, output_ids], 1)
    #     segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
    #     with graph.as_default():
    #         K.set_session(sess)
    #         nodes = self.last_token(self.model).predict([token_ids, segment_ids])
    #         return nodes
    #     # return self.last_token(self.model).predict([token_ids, segment_ids])

    # @AutoRegressiveDecoder.wraps(default_rtype='probas')
    # def predict(self, inputs, output_ids, states):
    #     c_encoded = inputs[0]
    #     with graph.as_default():
    #         K.set_session(sess)
    #         nodes = self.last_token(self.decoder).predict([c_encoded, output_ids])
    #         return nodes

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        c_encoded = inputs[0]
        with graph.as_default():
            K.set_session(sess)
            nodes = self.last_token(self.decoder).predict([c_encoded, output_ids])
            return nodes
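
    # last_token(...) wraps the decoder so that only the probability distribution
    # of the final decoding step is returned; the graph/session block re-activates
    # the globals captured at import time (see the note near the top of the file).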

    def generate(self, text, topk=3):
        c_token_ids, _ = self.tokenizer.encode(text, maxlen=120)
        c_encoded = self.encoder.predict(np.array([c_token_ids]))[0]
        output_ids = self.beam_search([c_encoded], topk=topk)  # beam-search decoding
        return self.tokenizer.decode([int(i) for i in output_ids])

    def generate_random(self, text, n=30, topp=0.9):
        c_token_ids, _ = self.tokenizer.encode(text, maxlen=120)
        with graph.as_default():
            K.set_session(sess)
            c_encoded = self.encoder.predict(np.array([c_token_ids]))[0]
            output_ids = self.random_sample([c_encoded], n, topp=topp)  # random (top-p) sampling
        text = []
        for ids in output_ids:
            text.append(self.tokenizer.decode([int(i) for i in ids]))
        return text
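
    # generate() returns the single best beam-search rewrite; generate_random()
    # returns a list of n sampled rewrites (nucleus sampling controlled by topp).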


generatemodel = GenerateModel()
encoder, decoder, model, tokenizer = generatemodel.device_setup()
autotitle = AutoTitle(encoder, decoder, model, tokenizer, start_id=0, end_id=tokenizer._token_end_id, maxlen=120)
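
# Example usage (illustrative only; any input sentence works):
# rewritten = autotitle.generate("...")              # best beam-search rewrite
# variants = autotitle.generate_random("...", n=5)   # several sampled rewrites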


def just_show_sentence(file):
    """
    @param file: list of input sentences; only the first element is rewritten.
    """
    text = file[0]
    pre = autotitle.generate(text)
    return pre


if __name__ == '__main__':
    # file = "train_2842.txt"
    # just_show(file)
    # text = ["历史和当下都证明,创新是民族生存、发展的不竭源泉,是自身发展的必然选择,是时代对于青年们的深切呼唤"]
    # a = just_show_sentence(text)
    # print(a)
    # print(type(a))
    is_novel = False
    path = "./data/700条论文测试.xlsx"
    df_list = pd.read_excel(path).values.tolist()
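
    # Assumption about the spreadsheet layout: column 0 holds the original text and
    # column 1 the existing (yy) rewrite; the mT5 output is appended as a third column.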

    df_list_new = []
    print(len(df_list))
    for i in tqdm(df_list):
        pre = just_show_sentence([i[0]])
        df_list_new.append([i[0], i[1], pre])

    df = pd.DataFrame(df_list_new, columns=["原文", "yy降重", "t5模型"])
    df.to_excel("./data/700条论文测试_7.xlsx", index=False)