You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
105 lines
3.1 KiB
105 lines
3.1 KiB
![]()
2 years ago
|
# -*- coding:utf-8 -*-
|
||
|
|
||
|
import os
|
||
|
# Restrict CUDA device visibility to index "1" for this process. This is set
# here, ahead of the keras / tensorflow imports that follow.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||
|
import json
|
||
|
import random
|
||
|
import keras
|
||
|
import numpy as np
|
||
|
import pandas as pd
|
||
|
from bert4keras.backend import multilabel_categorical_crossentropy
|
||
|
from bert4keras.models import build_transformer_model
|
||
|
from bert4keras.optimizers import Adam
|
||
|
from bert4keras.snippets import DataGenerator, sequence_padding
|
||
|
from keras.layers import Lambda, Dense
|
||
|
from keras.models import Model
|
||
|
from bert4keras.tokenizers import Tokenizer
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
import tensorflow as tf
|
||
|
from keras.backend import set_session
|
||
|
# Register a TensorFlow session with Keras whose GPU options have
# allow_growth enabled.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
set_session(tf.Session(config=config))
|
||
|
|
||
|
# Files of the pretrained RoFormer-V2 checkpoint directory: model config,
# checkpoint weights and vocabulary.
config_path = "chinese_roformer-v2-char_L-12_H-768_A-12/bert_config.json"
checkpoint_path = "chinese_roformer-v2-char_L-12_H-768_A-12/bert_model.ckpt"
dict_path = "chinese_roformer-v2-char_L-12_H-768_A-12/vocab.txt"

# class_nums: width of the Dense classification head.
# batch_size: not referenced by the code visible in this file.
# max_len:    maxlen passed to the tokenizer when encoding a text.
class_nums, batch_size, max_len = 168, 16, 512

# JSON label/threshold table read by load_label1() (the misspelled name is
# kept because that function refers to it) and the weights file loaded into
# the classifier.
config_lable = "./config_json/label_threshold.json"
weight_path = "./output_models/best_model.weights"
|
||
|
|
||
|
|
||
|
# Tokenizer built from the vocabulary file of the pretrained checkpoint.
tokenizer = Tokenizer(token_dict=dict_path)

# Keep the bert4keras wrapper object (return_keras_model=False) because both
# its `.model` and its `.initializer` are used below.
roformer = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    model='roformer_v2',
    return_keras_model=False,
)

# Slice index 0 along the second axis of the encoder output, then project it
# to one raw score per class (the Dense layer is given no activation).
cls_vector = Lambda(lambda x: x[:, 0])(roformer.model.output)
output = Dense(
    units=class_nums,
    kernel_initializer=roformer.initializer,
)(cls_vector)

# Assemble the classifier, load the weights from weight_path and print the
# layer summary.
model = Model(roformer.model.input, output)
model.load_weights(weight_path)
model.summary()
|
||
|
|
||
|
|
||
|
def load_label1(path=None):
    """Load the label table and the per-label decision thresholds.

    The JSON file is read as a mapping from label name to a sequence whose
    first item is the label id and whose second item is that label's
    threshold.

    Args:
        path: JSON file to read. When ``None`` (the default) the module-level
            ``config_lable`` path is used.

    Returns:
        tuple: ``(id2label, label2id, thresholds)`` where ``id2label`` maps
        label id to label name, ``label2id`` maps label name to label id, and
        ``thresholds`` is a NumPy array of thresholds sorted by ascending
        label id.
    """
    if path is None:
        path = config_lable
    with open(path, 'r', encoding='utf-8') as f:
        labels_dict = json.load(f)

    id2label1 = {entry[0]: name for name, entry in labels_dict.items()}
    label2id1 = {name: entry[0] for name, entry in labels_dict.items()}

    # Order thresholds by label id rather than by JSON key order, so that
    # position i of the array belongs to id2label1[i] even when the file is
    # not written in id order.
    # NOTE(review): this alignment assumes ids are the contiguous range
    # 0..n-1 -- confirm against the file that is actually deployed.
    by_id = sorted(labels_dict.values(), key=lambda entry: entry[0])
    label_threshold1 = np.array([entry[1] for entry in by_id])

    return id2label1, label2id1, label_threshold1
|
||
|
|
||
|
# Load the label table once at module level; predict() reads id2label and
# label_threshold from these globals.
id2label, label2id, label_threshold = load_label1()
|
||
|
|
||
|
def predict(text):
    """Return the labels whose score exceeds their per-label threshold.

    Args:
        text: A sequence whose first element is the string to classify.
            Only ``text[0]`` is used; any further elements are ignored.

    Returns:
        list: The ``id2label`` entries for every output position whose score
        is strictly greater than the matching entry of ``label_threshold``,
        in ascending position order.
    """
    token_ids, segment_ids = tokenizer.encode(text[0], maxlen=max_len)

    # The model is called with a batch of one, so each input is wrapped in an
    # outer list and the first row of the result is taken.
    scores = model.predict([[token_ids], [segment_ids]])[0]

    # Position i of the score vector is looked up as id2label[i]; int() turns
    # the NumPy index into a plain int for the dict lookup.
    return [id2label[int(i)] for i in np.flatnonzero(scores > label_threshold)]
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Batch run: classify the first column of every row of the input CSV and
    # append the predicted labels, comma-joined, as one extra trailing column.
    source_rows = pd.read_csv("data/yy改写相似度.csv").values.tolist()

    labelled_rows = [
        row + [",".join(predict([row[0]]))]
        for row in tqdm(source_rows)
    ]

    # NOTE: the rows were taken via .values.tolist(), so the input column
    # names are not carried over into this frame or into the written file.
    df = pd.DataFrame(labelled_rows)
    print(df)
    df.to_csv("./data/yy改写相似度含文章类别.csv", index=None)
|