# -*- coding: utf-8 -*-
"""
@Time     : 2023/3/13 10:15
@Author   :
@FileName :
@Software :
@Describe : bert4keras BERT wrapper that maps texts to [CLS]-position vectors.
"""
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open
from bert4keras.layers import ConditionalRandomField
from keras.layers import Dense
from keras.models import Model
from tqdm import tqdm
import json
from keras.layers import *


class ClassifyModel:
    """BERT encoder whose output is the sequence vector at position 0.

    The wrapped Keras model takes ``[token_ids, segment_ids]`` and returns
    ``x[:, 0]`` of the BERT model's output (the layer is named
    ``'CLS-token'``).
    """

    def __init__(self, config_path, checkpoint_path, dict_path, is_train,
                 load_weights_path=None, maxlen=256):
        """Build the Keras model and the tokenizer.

        Args:
            config_path: BERT config file passed to ``build_transformer_model``.
            checkpoint_path: pre-trained checkpoint passed to
                ``build_transformer_model``.
            dict_path: vocabulary file used to build the ``Tokenizer``.
            is_train: when falsy, weights from ``load_weights_path`` are
                loaded into the model after it is built.
            load_weights_path: weights file for ``model.load_weights``;
                required when ``is_train`` is falsy.
            maxlen: maximum token length used by ``data_generator``
                (default 256, the previously hard-coded value).

        Raises:
            ValueError: if ``is_train`` is falsy and no ``load_weights_path``
                was given.
        """
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.dict_path = dict_path
        # Honour the caller's flag; it used to be forced to True, which meant
        # the weights in load_weights_path were never loaded.
        self.is_train = is_train
        self.load_weights_path = load_weights_path
        self.model = self.create_model(self.is_train, self.load_weights_path)
        self.tokenizer = Tokenizer(self.dict_path, do_lower_case=True)
        self.maxlen = maxlen

    def create_model(self, is_train, load_weights_path):
        """Build the BERT encoder with a position-0 ("CLS-token") output.

        Args:
            is_train: when falsy, ``load_weights_path`` is loaded into the
                freshly built model.
            load_weights_path: weights file for ``model.load_weights``.

        Returns:
            A ``keras.models.Model`` mapping the BERT inputs to
            ``output[:, 0]``.

        Raises:
            ValueError: if weights must be loaded but ``load_weights_path``
                is None.
        """
        if not is_train and load_weights_path is None:
            # Fail fast with a clear message instead of an opaque Keras error
            # from load_weights(None), and before building the large model.
            raise ValueError(
                'load_weights_path is required when is_train is False')
        bert = build_transformer_model(
            config_path=self.config_path,
            checkpoint_path=self.checkpoint_path,
            return_keras_model=False,
        )
        # Use the same keras backend as keras.models.Model below, so the
        # layer and model come from one implementation even when bert4keras
        # is configured to use tf.keras.
        output = keras.layers.Lambda(
            lambda x: x[:, 0], name='CLS-token')(bert.model.output)
        model = keras.models.Model(bert.model.input, output)
        if not is_train:
            model.load_weights(load_weights_path)
        return model

    def predict(self, token_ids, segment_ids):
        """Run the model on one already-encoded, padded batch.

        Args:
            token_ids: padded token-id batch, e.g. one element of the first
                list returned by ``data_generator``.
            segment_ids: the matching padded segment-id batch.

        Returns:
            Whatever ``self.model.predict`` returns for the batch.
        """
        return self.model.predict([token_ids, segment_ids])

    def data_generator(self, texts, batch_size):
        """Tokenize ``texts`` and group them into padded batches.

        Each text is encoded with ``self.tokenizer`` (truncated to
        ``self.maxlen``). Every ``batch_size`` texts, plus the final
        remainder, are passed through ``sequence_padding`` to form one batch.

        Args:
            texts: sized sequence of strings (``len(texts)`` is used).
            batch_size: number of texts per batch.

        Returns:
            ``(batch_token_ids, batch_segment_ids)``: two lists holding one
            padded batch each. Both lists are empty when ``texts`` is empty.
        """
        batch_token_ids = []
        batch_segment_ids = []
        cur_token_ids, cur_segment_ids = [], []
        last_index = len(texts) - 1
        for index, text in enumerate(texts):
            token_ids, segment_ids = self.tokenizer.encode(
                text, maxlen=self.maxlen)
            cur_token_ids.append(token_ids)
            cur_segment_ids.append(segment_ids)
            # Flush on a full batch, or on the last text so the remainder is
            # not dropped.
            if len(cur_token_ids) == batch_size or index == last_index:
                batch_token_ids.append(sequence_padding(cur_token_ids))
                batch_segment_ids.append(sequence_padding(cur_segment_ids))
                cur_token_ids, cur_segment_ids = [], []
        return batch_token_ids, batch_segment_ids