You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

117 lines
3.7 KiB

# -*- coding:utf-8 -*-
import os
# Restrict this process to GPU 0. Keep this assignment above the
# TensorFlow/Keras imports below: CUDA reads the variable when it initialises,
# so setting it later may have no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
import random
import keras
import numpy as np
import pandas as pd
from bert4keras.backend import multilabel_categorical_crossentropy
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam
from bert4keras.snippets import DataGenerator, sequence_padding
from keras.layers import Lambda, Dense
from keras.models import Model
from bert4keras.tokenizers import Tokenizer
from tqdm import tqdm
import tensorflow as tf
from keras.backend import set_session
# Let TensorFlow grow GPU memory on demand instead of reserving the whole card
# up front, and register that session as the one Keras should use.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config))

# Pretrained RoFormer-v2 checkpoint: model config, weights and vocabulary.
pretrained_dir = 'chinese_roformer-v2-char_L-12_H-768_A-12'
config_path = pretrained_dir + '/bert_config.json'
checkpoint_path = pretrained_dir + '/bert_model.ckpt'
dict_path = pretrained_dir + '/vocab.txt'

# Classifier / inference settings.
class_nums = 168  # number of units in the classification head
batch_size = 16  # NOTE(review): not referenced by the visible code
max_len = 512  # maximum token length passed to tokenizer.encode
config_lable = './config_json/label_threshold.json'  # label -> [id, threshold] table
weight_path = './output_models/best_model.weights'  # fine-tuned weights to load
# Tokenizer built from the checkpoint's vocabulary file.
tokenizer = Tokenizer(token_dict=dict_path)
# RoFormer-v2 encoder initialised from the pretrained checkpoint. With
# return_keras_model=False bert4keras returns its wrapper object; the
# underlying Keras model is reached through `.model`.
roformer = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    model='roformer_v2',
    return_keras_model=False
)
# Keep only the hidden state at sequence position 0 as the text
# representation. Assumes the encoder output is (batch, seq_len, hidden) and
# that position 0 is the tokenizer's leading special token (presumably [CLS])
# -- TODO confirm.
output = Lambda(lambda x: x[:, 0])(roformer.model.output)
# Classification head: one unit per class. No activation is passed, so Keras'
# default linear activation applies and the outputs are raw scores, which
# predict() compares against per-label thresholds.
output = Dense(
    units=class_nums,
    kernel_initializer=roformer.initializer
)(output)
model = Model(roformer.model.input, output)
# Restore the fine-tuned weights. The layers above presumably have to match
# the training-time architecture (same layers, same order) for this to load.
model.load_weights(weight_path)
model.summary()
def load_label1(path=None):
    """Load the label table and per-label decision thresholds from JSON.

    The JSON object maps each label name to a list whose first item is the
    class id (the index predict() uses into the model output) and whose
    second item is that label's decision threshold.

    Args:
        path: JSON file to read. Defaults to the module-level ``config_lable``.

    Returns:
        A tuple ``(id2label, label2id, thresholds)``:
        ``id2label`` maps class id -> label name, ``label2id`` maps label
        name -> class id, and ``thresholds`` is a NumPy array whose k-th entry
        is the threshold of the label with the k-th smallest class id.
    """
    if path is None:
        path = config_lable
    with open(path, 'r', encoding='utf-8') as f:
        labels_dict = json.load(f)
    id2label1 = {entry[0]: name for name, entry in labels_dict.items()}
    label2id1 = {name: entry[0] for name, entry in labels_dict.items()}
    # Order thresholds by class id, not by JSON key order: predict() pairs
    # threshold k with model output k, so a JSON file that is not sorted by id
    # would otherwise silently misalign every threshold.
    entries_by_id = sorted(labels_dict.values(), key=lambda entry: entry[0])
    label_threshold1 = np.array([entry[1] for entry in entries_by_id])
    return id2label1, label2id1, label_threshold1
# Module-level label tables used by predict(): class id -> label name,
# label name -> class id, and the per-class decision thresholds.
id2label, label2id, label_threshold = load_label1()
def predict(text):
    """Return the predicted label names for a single text.

    Args:
        text: a sequence whose first element is the text to classify (callers
            pass a one-element list). A plain ``str`` is also accepted and is
            classified as a whole rather than by its first character.

    Returns:
        A non-empty list of label names. Every label whose raw model score
        exceeds its threshold is included, ordered by margin
        (score - threshold) from largest to smallest; ties keep ascending
        class-index order. If no label clears its threshold, the single label
        with the largest margin is returned instead.
    """
    # Callers wrap the text in a list; only the first item is classified.
    if not isinstance(text, str):
        text = text[0]
    token_ids, segment_ids = tokenizer.encode(text, maxlen=max_len)
    # Batch of one sample; keep that sample's score vector.
    scores = model.predict([np.array([token_ids]), np.array([segment_ids])])[0]
    # (class index, margin over threshold); zip() stops at the shorter input.
    margins = [
        (index, score - threshold)
        for index, (score, threshold) in enumerate(zip(scores, label_threshold))
    ]
    # list.sort is stable, so equal margins stay in class-index order.
    margins.sort(key=lambda item: item[1], reverse=True)
    label_pre = [id2label[index] for index, margin in margins if margin > 0]
    if not label_pre:
        # Nothing cleared its threshold: fall back to the best-margin label.
        label_pre.append(id2label[margins[0][0]])
    return label_pre
def main(input_path="data/yy改写相似度.csv",
         output_path="./data/yy改写相似度含文章类别.csv"):
    """Classify the first column of every row of a CSV and write the result.

    Each output row is the original row followed by one extra column holding
    the predicted label names of that row's first cell.

    Args:
        input_path: CSV to read (its first row is treated as a header).
        output_path: CSV to write, without an index column.
    """
    rows = pd.read_csv(input_path).values.tolist()
    labelled_rows = []
    for row in tqdm(rows):
        text = row[0]
        if pd.isna(text):
            # An empty cell arrives as NaN, which the tokenizer cannot encode;
            # give it an empty label instead of aborting the whole run.
            label_str = ""
        else:
            # Numeric cells (parsed by pandas as numbers) are classified as text.
            # NOTE(review): labels are concatenated with no separator, so
            # multi-label results cannot be split apart again -- confirm
            # downstream consumers before changing the separator.
            label_str = "".join(predict([str(text)]))
        labelled_rows.append(row + [label_str])
    # NOTE(review): the input's column names are not carried over; the output
    # header is the positional column numbers 0..k.
    pd.DataFrame(labelled_rows).to_csv(output_path, index=False)


if __name__ == '__main__':
    main()