|
|
|
# -*- coding:utf-8 -*-
|
|
|
|
|
|
|
|
import os
|
|
|
|
# Restrict TensorFlow/Keras to GPU 0 (must be set before TensorFlow is imported).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
|
|
import json
|
|
|
|
import random
|
|
|
|
import keras
|
|
|
|
import numpy as np
|
|
|
|
import pandas as pd
|
|
|
|
from bert4keras.backend import multilabel_categorical_crossentropy
|
|
|
|
from bert4keras.models import build_transformer_model
|
|
|
|
from bert4keras.optimizers import Adam
|
|
|
|
from bert4keras.snippets import DataGenerator, sequence_padding
|
|
|
|
from keras.layers import Lambda, Dense
|
|
|
|
from keras.models import Model
|
|
|
|
from bert4keras.tokenizers import Tokenizer
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
import tensorflow as tf
|
|
|
|
from keras.backend import set_session
|
|
|
|
# TF1-style session setup: allocate GPU memory on demand instead of
# grabbing it all up front, then register the session with Keras.
config = tf.ConfigProto()

config.gpu_options.allow_growth = True

set_session(tf.Session(config=config))  # NOTE: this line differs (from the companion script)
|
|
|
|
|
|
|
|
# Pretrained RoFormer-v2 (Chinese, char-level) checkpoint files.
config_path = 'chinese_roformer-v2-char_L-12_H-768_A-12/bert_config.json'

checkpoint_path = 'chinese_roformer-v2-char_L-12_H-768_A-12/bert_model.ckpt'

dict_path = 'chinese_roformer-v2-char_L-12_H-768_A-12/vocab.txt'


# Task hyperparameters.
class_nums = 168  # number of output labels
batch_size = 16  # NOTE(review): unused in this inference script — presumably kept from training; confirm
max_len = 512  # maximum tokens per document fed to the encoder

# Per-label decision thresholds (JSON: label -> [id, threshold]) and fine-tuned weights.
config_lable = './config_json/label_threshold.json'  # NOTE(review): 'lable' typo kept — name is referenced below

weight_path = './output_models/best_model.weights'


# Tokenizer built from the pretrained vocabulary file.
tokenizer = Tokenizer(token_dict=dict_path)
|
|
|
|
|
|
|
|
# Build the RoFormer-v2 backbone from the pretrained checkpoint.
# return_keras_model=False keeps the bert4keras wrapper so that
# roformer.initializer is available for the classification head below.
roformer = build_transformer_model(

config_path=config_path,

checkpoint_path=checkpoint_path,

model='roformer_v2',

return_keras_model=False

)


# [CLS]-style pooling: use the first token's hidden state as the document vector.
output = Lambda(lambda x: x[:, 0])(roformer.model.output)


# Linear head: one score per label (class_nums outputs, no activation —
# thresholds are applied downstream in predict()).
output = Dense(

units=class_nums,

kernel_initializer=roformer.initializer

)(output)


# Assemble the full inference model and load the fine-tuned weights.
model = Model(roformer.model.input, output)

model.load_weights(weight_path)

model.summary()
|
|
|
|
|
|
|
|
|
|
|
|
def load_label1(path=None):
    """Load the label/threshold config and build the lookup tables.

    The JSON file maps label name -> [label_id, threshold].

    Args:
        path: optional path to the JSON config file. Defaults to the
            module-level ``config_lable`` path when omitted, which keeps
            the original zero-argument call working unchanged.

    Returns:
        Tuple ``(id2label, label2id, thresholds)`` where ``thresholds``
        is a ``np.ndarray`` aligned with the label ids (this relies on
        the JSON insertion order matching the id order, as the original
        code did).
    """
    if path is None:
        path = config_lable  # module-level default ('lable' spelling kept for compatibility)

    with open(path, 'r', encoding='utf-8') as f:
        labels_dict = json.load(f)

    id2label1 = {spec[0]: name for name, spec in labels_dict.items()}
    label2id1 = {name: spec[0] for name, spec in labels_dict.items()}
    label_threshold1 = np.array([spec[1] for spec in labels_dict.values()])

    return id2label1, label2id1, label_threshold1
|
|
|
|
|
|
|
|
# Module-level label lookup tables and per-label thresholds used by predict().
id2label, label2id, label_threshold = load_label1()
|
|
|
|
|
|
|
|
def predict(text):
    """Predict label names for a single document.

    Args:
        text: a one-element sequence whose first item is the document
            string to classify (only ``text[0]`` is used, as before).

    Returns:
        List of label names whose score exceeds its per-label threshold,
        ordered by margin (score minus threshold) descending. If no label
        clears its threshold, the single best-margin label is returned as
        a fallback so the result is never empty.
    """
    sample = text[0]
    token_ids, segment_ids = tokenizer.encode(sample, maxlen=max_len)
    y_pred = model.predict([[token_ids], [segment_ids]])

    # Margin of each label's score over its calibrated threshold; the
    # original built a dict and sorted it — a vectorized argsort on the
    # negated margins yields the same descending order.
    margins = y_pred[0] - label_threshold
    order = np.argsort(-margins)

    label_pre = [id2label[int(idx)] for idx in order if margins[idx] > 0]
    if not label_pre:
        # Nothing cleared its threshold: fall back to the closest label.
        label_pre.append(id2label[int(order[0])])
    return label_pre
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Batch-label a CSV: the first column of each row is the text to
    # classify; the predicted label names (comma-joined) are appended as
    # a new trailing column and the result is written to a new CSV.
    data = pd.read_csv("data/yy改写相似度.csv").values.tolist()

    data_new = []
    for row in tqdm(data):
        label_pre = predict([row[0]])
        data_new.append(row + [",".join(label_pre)])

    df = pd.DataFrame(data_new)
    df.to_csv("./data/yy改写相似度含文章类别.csv", index=None)
|