You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

65 lines
2.5 KiB

# -*- coding: utf-8 -*-
"""
@Time : 2023/3/13 10:15
@Author :
@FileName:
@Software:
@Describe:
"""
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open
from bert4keras.layers import ConditionalRandomField
from keras.layers import Dense
from keras.models import Model
from tqdm import tqdm
import json
from keras.layers import *
class ClassifyModel:
    """Wrap a BERT encoder whose output is the vector at the [CLS] position.

    The Keras model is built from a bert4keras config/checkpoint. When
    ``is_train`` is false, weights from ``load_weights_path`` are loaded on
    top of the freshly built model.
    """

    def __init__(self, config_path, checkpoint_path, dict_path, is_train,
                 load_weights_path=None, maxlen=256):
        """Build the encoder model and its tokenizer.

        Args:
            config_path: Path to the bert4keras model config.
            checkpoint_path: Path to the pretrained checkpoint.
            dict_path: Path to the tokenizer vocabulary file.
            is_train: If false, ``load_weights_path`` is loaded into the model.
            load_weights_path: Weight file to load; required when
                ``is_train`` is false.
            maxlen: Maximum token length passed to ``Tokenizer.encode``.

        Raises:
            ValueError: If ``is_train`` is false but no weight path is given.
        """
        # Fail fast with a clear message instead of letting
        # ``model.load_weights(None)`` fail deep inside Keras.
        if not is_train and load_weights_path is None:
            raise ValueError(
                "load_weights_path is required when is_train is False")
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.dict_path = dict_path
        # Honour the caller's flag: it decides whether create_model loads
        # the saved weights.
        self.is_train = is_train
        self.load_weights_path = load_weights_path
        self.maxlen = maxlen
        self.model = self.create_model(self.is_train, self.load_weights_path)
        self.tokenizer = Tokenizer(self.dict_path, do_lower_case=True)

    def create_model(self, is_train, load_weights_path):
        """Build a Keras model mapping ``[token_ids, segment_ids]`` to [CLS] output.

        Args:
            is_train: If false, ``load_weights_path`` is loaded into the model.
            load_weights_path: Weight file passed to ``model.load_weights``.

        Returns:
            A ``keras.models.Model`` whose output is index 0 along axis 1 of
            the transformer's output.
        """
        bert = build_transformer_model(
            config_path=self.config_path,
            checkpoint_path=self.checkpoint_path,
            return_keras_model=False,
        )
        # Keep only the first position (the [CLS] token). Assumes the encoder
        # output is (batch, seq_len, hidden) -- TODO confirm.
        output = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)
        model = keras.models.Model(bert.model.input, output)
        if not is_train:
            model.load_weights(load_weights_path)
        return model

    def predict(self, token_ids, segment_ids):
        """Run the model on one batch of token ids and segment ids.

        Presumably the inputs are padded batches such as those returned by
        ``data_generator``; the result is whatever ``model.predict`` returns.
        """
        return self.model.predict([token_ids, segment_ids])

    def data_generator(self, texts, batch_size):
        """Tokenize ``texts`` and group them into padded batches.

        Despite its name, this builds and returns lists eagerly rather than
        yielding batches.

        Args:
            texts: Sequence of strings (``len`` is used, so a bare iterator
                will not work).
            batch_size: Number of texts per batch; the last batch may be
                smaller.

        Returns:
            ``(token_id_batches, segment_id_batches)``: two parallel lists,
            each holding one ``sequence_padding`` result per batch. Empty
            ``texts`` yields ``([], [])``.
        """
        token_batches, segment_batches = [], []
        pending_tokens, pending_segments = [], []
        last_index = len(texts) - 1
        for index, text in enumerate(texts):
            token_ids, segment_ids = self.tokenizer.encode(text, maxlen=self.maxlen)
            pending_tokens.append(token_ids)
            pending_segments.append(segment_ids)
            # Flush when the batch is full or when this is the final text.
            if len(pending_tokens) == batch_size or index == last_index:
                token_batches.append(sequence_padding(pending_tokens))
                segment_batches.append(sequence_padding(pending_segments))
                pending_tokens, pending_segments = [], []
        return token_batches, segment_batches