# -*- coding: utf-8 -*-
# @Time: 16:41
# @Author:ZYP
# @File:roformer_api.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm
# =========================================================================================
#   Load the deep learning models
#   · Load the paper classification model
#   · Load the BERT model
# =========================================================================================
import json
import os

import numpy as np
from bert4keras.models import build_transformer_model
from keras.layers import Lambda, Dense
from keras.models import Model
from bert4keras.tokenizers import Tokenizer
from bert4keras.backend import K
import tensorflow as tf
from keras.backend import set_session
from flask import Flask, request

# =========================================================================================================================
# Parameters of the roformer model
# =========================================================================================================================
# Number of output units of the classification head.
class_nums = 168
# Maximum number of token ids kept per encoded request.
max_len = 512

roformer_config_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/chinese_roformer-v2-char_L-12_H-768_A-12/bert_config.json'
roformer_ckpt_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/chinese_roformer-v2-char_L-12_H-768_A-12/bert_model.ckpt'
roformer_vocab_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/chinese_roformer-v2-char_L-12_H-768_A-12/vocab.txt'
roformer_model_weights_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/model/model3/best_model.weights'
label_path = '/home/zc-nlp-zyp/work_file/ssd_data/program/zhiwang_VSM/class_analysis/max_class_train/data/label_threshold.txt'

# Session options: enable on-demand GPU memory growth.
tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True

# Build exactly one session, bound to the default graph AND carrying the GPU
# options. Previously the options were given to a throwaway session that was
# never closed, while the session actually used was created without them.
# `graph` and `sess` are module-level names used later when loading weights
# and when running predictions.
graph = tf.get_default_graph()
sess = tf.Session(graph=graph, config=tfconfig)
set_session(sess)
# NOTE(review): this is assigned after the TensorFlow session has already been
# created earlier in the module; confirm the restriction to GPU 0 actually
# takes effect at this point.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

app_roformer = Flask(__name__)


def load_roformer_model(config, ckpt):
    """Build the roformer-based multi-label classifier (``class_nums`` outputs).

    Args:
        config: Path to the transformer config JSON file.
        ckpt: Path to the pretrained transformer checkpoint.

    Returns:
        A Keras ``Model`` mapping the transformer inputs to a dense layer of
        ``class_nums`` units applied to the first sequence position. The
        fine-tuned weights are NOT loaded here; the caller does that.
    """
    roformer = build_transformer_model(
        config_path=config,
        checkpoint_path=ckpt,
        model='roformer_v2',
        return_keras_model=False)

    # Keep only the vector at sequence position 0 as the pooled representation.
    output = Lambda(lambda x: x[:, 0])(roformer.model.output)
    output = Dense(
        units=class_nums,
        kernel_initializer=roformer.initializer
    )(output)

    model1 = Model(roformer.model.input, output)
    model1.summary()
    return model1


def load_label(label_path1):
    """Load label2id, id2label and the per-class thresholds used for classification.

    The file is JSON mapping each label name to a sequence whose element 0 is
    the label id and element 1 is that label's decision threshold.

    Args:
        label_path1: Path to the JSON label/threshold file.

    Returns:
        A tuple ``(label2id, id2label, label_threshold)`` where the first two
        are dicts and the last is a numpy array of thresholds in the file's
        key order (presumably matching label-id order — verify against the
        data file).
    """
    with open(label_path1, 'r', encoding='utf-8') as f:
        json_dict = json.load(f)

    label2id1 = {label: info[0] for label, info in json_dict.items()}
    id2label1 = {info[0]: label for label, info in json_dict.items()}
    label_threshold1 = np.array([info[1] for info in json_dict.values()])
    return label2id1, id2label1, label_threshold1


# Load the label metadata.
label2id, id2label, label_threshold = load_label(label_path)
# Tokenizer for the roformer model.
tokenizer_roformer = Tokenizer(roformer_vocab_path)
# Build the model.
model_roformer = load_roformer_model(roformer_config_path, roformer_ckpt_path)
set_session(sess)
# Load the fine-tuned weights.
model_roformer.load_weights(roformer_model_weights_path)


def encode(text_list1):
    """Encode a list of texts into one concatenated token/segment sequence.

    Every text is tokenized separately and the results are concatenated. For
    all texts after the first, the leading token (presumably ``[CLS]``) is
    dropped. Texts at odd positions (1, 3, ...) get segment id 1 for all of
    their tokens. The concatenation is truncated to ``max_len`` entries.

    Args:
        text_list1: List of strings to encode, in order.

    Returns:
        A tuple ``(token_ids, segment_ids)`` of numpy arrays, each wrapping a
        single sequence (a batch of one).
    """
    sent_token_id1, sent_segment_id1 = [], []
    for index, text in enumerate(text_list1):
        token_id, segment_id = tokenizer_roformer.encode(text)
        if index > 0:
            # Only the very first text keeps its leading token.
            token_id = token_id[1:]
            segment_id = segment_id[1:]
        if index % 2 == 1:
            segment_id = [1] * len(token_id)
        sent_token_id1 += token_id
        sent_segment_id1 += segment_id

    # Slicing is a no-op when the sequence is already short enough.
    sent_token_id1 = sent_token_id1[:max_len]
    sent_segment_id1 = sent_segment_id1[:max_len]

    sent_token_id = np.array([sent_token_id1])
    sent_segment_id = np.array([sent_segment_id1])
    return sent_token_id, sent_segment_id


@app_roformer.route('/roformer', methods=['POST'])
def pred_class_num():
    """Classify the submitted paper and return the predicted class indices.

    The request body must be a JSON object; this handler reads its ``title``
    and ``abst_zh`` fields. An abstract of at most 10 "。"-separated pieces is
    used whole; otherwise only its first five and last five pieces are used.

    Returns:
        A JSON string ``{"label_num": [...]}`` listing the indices whose score
        exceeds the per-class threshold, or the plain string
        ``'error_roformer'`` if anything goes wrong while handling the request.
    """
    try:
        target_paper_dict = json.loads(request.data.decode())
        text_list1 = [target_paper_dict['title']]  # , target_paper_dict['key_words']

        abst_zh = target_paper_dict['abst_zh']
        sentences = abst_zh.split("。")  # split once, reused below
        if len(sentences) <= 10:
            text_list1.append(abst_zh)
        else:
            text_list1.append("。".join(sentences[:5]))
            text_list1.append("。".join(sentences[-5:]))

        sent_token, segment_ids = encode(text_list1)

        # Predict inside the graph/session the model was loaded with.
        with graph.as_default():
            K.set_session(sess)
            y_pred = model_roformer.predict([sent_token, segment_ids])

        # Indices of every class whose score is above its own threshold.
        hit_indices = np.flatnonzero(y_pred[0] > label_threshold).tolist()
        pred_label_num_dict = {'label_num': hit_indices}
        return json.dumps(pred_label_num_dict)
    except Exception:
        # Keep the sentinel response callers rely on, but record the traceback
        # instead of discarding it; SystemExit/KeyboardInterrupt now propagate.
        app_roformer.logger.exception('roformer prediction failed')
        return 'error_roformer'


# Entry point for running the service directly (left disabled):
# if __name__ == '__main__':
#     app_roformer.run('0.0.0.0', port=50003, debug=False)