# -*- coding:utf-8 -*-
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import json
import random

import keras
import numpy as np
import pandas as pd
from bert4keras.backend import multilabel_categorical_crossentropy
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam
from bert4keras.snippets import DataGenerator, sequence_padding
from keras.layers import Lambda, Dense
from keras.models import Model
from bert4keras.tokenizers import Tokenizer
from tqdm import tqdm
import tensorflow as tf
from keras.backend import set_session

# This part is different: TF 1.x session setup so GPU memory grows on demand.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config))

# Pretrained RoFormer-V2 checkpoint files.
config_path = 'chinese_roformer-v2-char_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'chinese_roformer-v2-char_L-12_H-768_A-12/bert_model.ckpt'
dict_path = 'chinese_roformer-v2-char_L-12_H-768_A-12/vocab.txt'

class_nums = 168
batch_size = 16
max_len = 512
config_label = './config_json/label_threshold.json'
weight_path = './output_models/best_model.weights'

tokenizer = Tokenizer(token_dict=dict_path)

# Build RoFormer-V2 and add a multi-label classification head on the [CLS] vector.
roformer = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    model='roformer_v2',
    return_keras_model=False
)
output = Lambda(lambda x: x[:, 0])(roformer.model.output)  # [CLS] pooling
output = Dense(
    units=class_nums,
    kernel_initializer=roformer.initializer
)(output)
model = Model(roformer.model.input, output)
model.load_weights(weight_path)
model.summary()


def load_label1():
    """Load the id<->label mappings and the per-label decision thresholds."""
    with open(config_label, 'r', encoding='utf-8') as f:
        labels_dict = json.load(f)
    id2label1 = {j[0]: i for i, j in labels_dict.items()}
    label2id1 = {i: j[0] for i, j in labels_dict.items()}
    label_threshold1 = np.array([j[1] for i, j in labels_dict.items()])
    return id2label1, label2id1, label_threshold1


id2label, label2id, label_threshold = load_label1()


def predict(text):
    """Predict the label set for a single text, passed as a one-element list."""
    text = text[0]
    token_ids, segment_ids = tokenizer.encode(text, maxlen=max_len)
    y_pred = model.predict([[token_ids], [segment_ids]])
    # A label is assigned when its score exceeds its own threshold.
    idx = np.where(y_pred[0] > label_threshold, 1, 0)
    label_pre = []
    for i in range(len(idx)):
        if idx[i] == 1:
            label_pre.append(id2label[i])
    return label_pre


if __name__ == '__main__':
    # Single-text example:
    # text_list = ["你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你"]
    # label_pre = predict(text_list)
    # print(label_pre)
    # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # Batch inference: read texts from the CSV, predict labels for each row,
    # and append them as a new comma-joined column.
    data = pd.read_csv("data/yy改写相似度.csv").values.tolist()
    data_new = []
    for data_dan in tqdm(data):
        label_pre = predict([data_dan[0]])
        label_pre = ",".join(label_pre)
        data_new.append(data_dan + [label_pre])
    df = pd.DataFrame(data_new)
    print(df)
    df.to_csv("./data/yy改写相似度含文章类别.csv", index=False)
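

# ---------------------------------------------------------------------------
# Reference sketch (assumption, not part of the original pipeline): load_label1()
# expects config_label to point at a JSON file that maps each label name to a
# [class_id, decision_threshold] pair. The helper below shows one way such a
# file could be written; it is never called in this script, and the example
# label names and thresholds in the comment are hypothetical placeholders.
def write_label_threshold_json(label_to_id_threshold, path=config_label):
    # e.g. label_to_id_threshold = {"体育": [0, 0.0], "科技": [1, 0.0]}
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(label_to_id_threshold, f, ensure_ascii=False, indent=2)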