# -*- coding: utf-8 -*-
"""
@Time : 2023/3/13 10:15
@Author :
@FileName:
@Software:
@Describe:
"""
from bert4keras.backend import keras
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import sequence_padding
from keras.layers import Lambda


class ClassifyModel:
    def __init__(self, config_path, checkpoint_path, dict_path, is_train, load_weights_path=None):
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.dict_path = dict_path
        self.is_train = is_train  # use the argument instead of hard-coding True
        self.load_weights_path = load_weights_path
        self.maxlen = 256
        self.model = self.create_model(self.is_train, self.load_weights_path)
        self.tokenizer = Tokenizer(self.dict_path, do_lower_case=True)

    def create_model(self, is_train, load_weights_path):
        bert = build_transformer_model(
            config_path=self.config_path,
            checkpoint_path=self.checkpoint_path,
            return_keras_model=False,
        )
        # Take the [CLS] token vector as the sentence representation.
        output = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)
        model = keras.models.Model(bert.model.input, output)
        if not is_train:
            # At inference time, restore previously fine-tuned weights.
            model.load_weights(load_weights_path)
        return model

    def predict(self, token_ids, segment_ids):
        # Returns the [CLS] sentence vectors for one padded batch.
        return self.model.predict([token_ids, segment_ids])

    def data_generator(self, texts, batch_size):
        """Split texts into padded batches of token ids and segment ids."""
        batch_token_ids = []
        batch_segment_ids = []
        # "dan" (单) lists accumulate the batch currently being built.
        batch_dan_token_ids = []
        batch_dan_segment_ids = []
        for id_, text in enumerate(texts):
            token_ids, segment_ids = self.tokenizer.encode(text, maxlen=self.maxlen)
            batch_dan_token_ids.append(token_ids)
            batch_dan_segment_ids.append(segment_ids)
            # Flush when the batch is full or we have reached the last text.
            if len(batch_dan_token_ids) == batch_size or id_ == len(texts) - 1:
                batch_token_ids.append(sequence_padding(batch_dan_token_ids))
                batch_segment_ids.append(sequence_padding(batch_dan_segment_ids))
                batch_dan_token_ids, batch_dan_segment_ids = [], []
        return batch_token_ids, batch_segment_ids
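

# Minimal usage sketch. The three paths must point to a local BERT release and
# 'weights.h5' is a hypothetical fine-tuned weights file; none of these names
# come from this repository.
if __name__ == '__main__':
    clf = ClassifyModel(
        config_path='bert_config.json',      # assumed path
        checkpoint_path='bert_model.ckpt',   # assumed path
        dict_path='vocab.txt',               # assumed path
        is_train=False,
        load_weights_path='weights.h5',      # hypothetical fine-tuned weights
    )
    token_batches, segment_batches = clf.data_generator(
        ['first example sentence', 'second example sentence'], batch_size=2
    )
    for token_ids, segment_ids in zip(token_batches, segment_batches):
        vectors = clf.predict(token_ids, segment_ids)
        print(vectors.shape)  # (batch_size, hidden_size) CLS vectors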