# -*- coding: utf-8 -*-
"""
@Time : 2023/3/13 10:15
@Author :
@FileName:
@Software:
@Describe:
"""
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open
from bert4keras.layers import ConditionalRandomField
from keras.layers import Dense
from keras.models import Model
from tqdm import tqdm
import json
from keras.layers import *


class ClassifyModel:

    def __init__(self, config_path, checkpoint_path, dict_path, is_train, load_weights_path=None):
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.dict_path = dict_path
        # True when building the model for training; False when restoring
        # trained weights from load_weights_path for inference.
        self.is_train = is_train
        self.load_weights_path = load_weights_path
        self.model = self.create_model(self.is_train, self.load_weights_path)
        self.tokenizer = Tokenizer(self.dict_path, do_lower_case=True)
        self.maxlen = 256

    def create_model(self, is_train, load_weights_path):
        # Build the BERT encoder from the pretrained config/checkpoint and
        # take the [CLS] token vector as the sentence representation.
        bert = build_transformer_model(
            config_path=self.config_path,
            checkpoint_path=self.checkpoint_path,
            return_keras_model=False,
        )
        output = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)
        model = keras.models.Model(bert.model.input, output)
        # For inference, restore previously trained weights.
        if not is_train:
            model.load_weights(load_weights_path)
        return model

    def predict(self, token_ids, segment_ids):
        # Run a forward pass on one padded batch of token/segment ids.
        return self.model.predict([token_ids, segment_ids])

    def data_generator(self, texts, batch_size):
        # Tokenize the input texts and pack them into padded batches.
        # Despite the name, this is not a Python generator: it returns two
        # parallel lists, one of padded token-id batches and one of padded
        # segment-id batches, each entry holding up to batch_size samples.
        batch_token_ids = []
        batch_segment_ids = []
        batch_dan_token_ids = []
        batch_dan_segment_ids = []
        for id_, text in enumerate(texts):
            token_ids, segment_ids = self.tokenizer.encode(text, maxlen=self.maxlen)
            batch_dan_token_ids.append(token_ids)
            batch_dan_segment_ids.append(segment_ids)
            # Flush the current batch when it is full or the last text is reached.
            if len(batch_dan_token_ids) == batch_size or id_ == len(texts) - 1:
                batch_dan_token_ids = sequence_padding(batch_dan_token_ids)
                batch_dan_segment_ids = sequence_padding(batch_dan_segment_ids)
                batch_token_ids.append(batch_dan_token_ids)
                batch_segment_ids.append(batch_dan_segment_ids)
                batch_dan_token_ids, batch_dan_segment_ids = [], []
        return batch_token_ids, batch_segment_ids
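

# A minimal usage sketch, assuming a locally downloaded BERT checkpoint.
# The three paths below are placeholders (not part of this repository), and
# with is_train=True no fine-tuned weights are loaded, so predict() returns
# the raw [CLS] vectors from the pretrained encoder.
if __name__ == '__main__':
    config_path = 'bert_config.json'      # placeholder path
    checkpoint_path = 'bert_model.ckpt'   # placeholder path
    dict_path = 'vocab.txt'               # placeholder path

    clf = ClassifyModel(config_path, checkpoint_path, dict_path, is_train=True)
    texts = ['this is the first example sentence', 'and this is the second one']
    token_batches, segment_batches = clf.data_generator(texts, batch_size=2)
    for token_ids, segment_ids in zip(token_batches, segment_batches):
        cls_vectors = clf.predict(token_ids, segment_ids)  # shape: (batch, hidden_size)
        print(cls_vectors.shape)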