You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
141 lines
5.5 KiB
141 lines
5.5 KiB
# -*- coding: utf-8 -*-
|
|
# @Time: 16:41
|
|
# @Author:ZYP
|
|
# @File:roformer_api.py
|
|
# @mail:zypsunshine1@gmail.com
|
|
# @Software: PyCharm
|
|
|
|
# =========================================================================================
|
|
# 加载深度学习模型
|
|
# · 加载论文分类模型
|
|
# · 加载 BERT 模型
|
|
# =========================================================================================
|
|
|
|
import json
|
|
import os
|
|
import numpy as np
|
|
from bert4keras.models import build_transformer_model
|
|
from keras.layers import Lambda, Dense
|
|
from keras.models import Model
|
|
from bert4keras.tokenizers import Tokenizer
|
|
from bert4keras.backend import K
|
|
import tensorflow as tf
|
|
from keras.backend import set_session
|
|
from flask import Flask, request
|
|
|
|
# =========================================================================================================================
# RoFormer model parameters
# =========================================================================================================================

# Pretrained RoFormer-v2 files: architecture config, checkpoint and vocabulary.
roformer_config_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/chinese_roformer-v2-char_L-12_H-768_A-12/bert_config.json'
roformer_ckpt_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/chinese_roformer-v2-char_L-12_H-768_A-12/bert_model.ckpt'
roformer_vocab_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/chinese_roformer-v2-char_L-12_H-768_A-12/vocab.txt'

# Trained classifier weights and the label/threshold table (read with json.load).
roformer_model_weights_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/model3/best_model.weights'
label_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/data/label_threshold.txt'

# Number of units of the final Dense layer, i.e. one output per class.
class_nums = 168
# Upper bound on the number of token ids that encode() keeps per request.
max_len = 512
|
|
|
|
# Restrict TensorFlow to the first GPU. The variable must be exported BEFORE
# the first session is created: GPU devices are enumerated at that point, so
# setting it afterwards (as this file used to do) has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Let TensorFlow grow its GPU memory allocation on demand.
tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True

# A single graph/session pair, shared by model loading at import time and by
# the request handler. Previously a configured session was created, registered
# and then immediately replaced by a second, unconfigured one.
graph = tf.get_default_graph()
sess = tf.Session(graph=graph, config=tfconfig)
set_session(sess)

# Flask application object that the route below is registered on.
app_roformer = Flask(__name__)
|
|
|
|
|
|
def load_roformer_model(config, ckpt):
    """Build the RoFormer-v2 multi-label classification network.

    The transformer is created with ``build_transformer_model``; the vector at
    position 0 of its output sequence is fed into a Dense layer with
    ``class_nums`` units (no activation is specified on that layer).

    Args:
        config: Path of the transformer configuration file.
        ckpt: Path of the pretrained checkpoint handed to bert4keras.

    Returns:
        The assembled Keras ``Model``. Its summary is printed as a side effect.
    """
    backbone = build_transformer_model(
        config_path=config,
        checkpoint_path=ckpt,
        model='roformer_v2',
        return_keras_model=False,
    )

    # Keep only the first position of the sequence output.
    first_position = Lambda(lambda seq: seq[:, 0])(backbone.model.output)

    classifier_head = Dense(
        units=class_nums,
        kernel_initializer=backbone.initializer,
    )
    classifier = Model(backbone.model.input, classifier_head(first_position))
    classifier.summary()
    return classifier
|
|
|
|
|
|
def load_label(label_path1):
    """Read the label table used for classification.

    The file is parsed as JSON and must hold an object whose values are
    sequences with the label id at index 0 and the threshold at index 1.

    Args:
        label_path1: Path of the UTF-8 encoded JSON file.

    Returns:
        A tuple ``(label2id, id2label, thresholds)``: a dict from label to id,
        a dict from id to label, and a NumPy array of thresholds.
    """
    with open(label_path1, 'r', encoding='utf-8') as handle:
        entries = json.load(handle)

    label2id1, id2label1, thresholds = {}, {}, []
    # NOTE(review): thresholds follow the order of the entries in the file,
    # not the label ids - presumably the file is already sorted by id; confirm.
    for label, info in entries.items():
        label_id, threshold = info[0], info[1]
        label2id1[label] = label_id
        id2label1[label_id] = label
        thresholds.append(threshold)

    return label2id1, id2label1, np.array(thresholds)
|
|
|
|
|
|
# Label metadata: label -> id, id -> label, and the threshold array.
label2id, id2label, label_threshold = load_label(label_path1=label_path)

# Tokenizer for the RoFormer model, built from its vocabulary file.
tokenizer_roformer = Tokenizer(roformer_vocab_path)

# Build the model.
model_roformer = load_roformer_model(
    config=roformer_config_path,
    ckpt=roformer_ckpt_path,
)
set_session(sess)

# Load the trained weights.
model_roformer.load_weights(roformer_model_weights_path)
|
|
|
|
|
|
def encode(text_list1):
    """Encode a list of texts into one concatenated model input.

    Each text is tokenized with ``tokenizer_roformer``. For every text after
    the first, the leading token id and segment id are dropped (presumably the
    start-of-sequence marker) before concatenation. Texts at odd 0-based
    positions (the 2nd, 4th, ...) get segment id 1 for all of their tokens.
    The result is cut to at most ``max_len`` ids.

    Args:
        text_list1: Iterable of strings to encode, in order.

    Returns:
        A tuple ``(token_ids, segment_ids)`` of NumPy arrays, each built from
        a single-row nested list.
    """
    all_tokens, all_segments = [], []

    for position, text in enumerate(text_list1):
        tokens, segments = tokenizer_roformer.encode(text)

        if position > 0:
            # Drop the first id so it is not repeated mid-sequence.
            tokens, segments = tokens[1:], segments[1:]
            if position % 2 == 1:
                segments = [1] * len(tokens)

        all_tokens.extend(tokens)
        all_segments.extend(segments)

        # Never keep more than max_len ids.
        if len(all_tokens) > max_len:
            del all_tokens[max_len:]
            del all_segments[max_len:]

    return np.array([all_tokens]), np.array([all_segments])
|
|
|
|
|
|
@app_roformer.route('/roformer', methods=['POST'])
def pred_class_num():
    """Classify a submitted paper and return the indices of its predicted labels.

    The request body must be a JSON object containing at least the keys
    ``title`` and ``abst_zh``. The title and the abstract are encoded together.
    When the abstract splits into more than ten pieces on the Chinese full
    stop "。", only its first five and last five pieces are used.

    Returns:
        A JSON string ``{"label_num": [...]}`` listing every output index whose
        score exceeds the threshold at the same index of ``label_threshold``.
        If anything fails, the plain string ``'error_roformer'`` is returned
        and the exception is written to the Flask application log.
    """
    try:
        target_paper_dict = json.loads(request.data.decode())
        # Only the title is used here; the key_words field is not included.
        text_list1 = [target_paper_dict['title']]
        abst_zh = target_paper_dict['abst_zh']

        # Split once and reuse the pieces for both the test and the slices.
        sentences = abst_zh.split("。")
        if len(sentences) <= 10:
            text_list1.append(abst_zh)
        else:
            text_list1.append("。".join(sentences[:5]))
            text_list1.append("。".join(sentences[-5:]))

        sent_token, segment_ids = encode(text_list1)

        # Run the prediction in the graph/session the model was loaded into;
        # request handlers may run in a different thread than the import.
        with graph.as_default():
            K.set_session(sess)
            y_pred = model_roformer.predict([sent_token, segment_ids])

        above_threshold = y_pred[0] > label_threshold
        pred_label_num_dict = {
            'label_num': [index for index, hit in enumerate(above_threshold) if hit],
        }
        return json.dumps(pred_label_num_dict)
    except Exception:
        # Keep the sentinel response callers rely on, but record the failure
        # instead of swallowing it silently. SystemExit and KeyboardInterrupt
        # are no longer intercepted.
        app_roformer.logger.exception('roformer prediction failed')
        return 'error_roformer'
|
|
|
|
|
|
# if __name__ == '__main__':
|
|
# app_roformer.run('0.0.0.0', port=50003, debug=False)
|
|
|